[docx] split commit for file 400

Signed-off-by: Ari Archer <ari.web.xyz@gmail.com>
Ari Archer 2024-01-23 19:04:21 +02:00
parent 6c4587804f
commit 3c586de8ec
No known key found for this signature in database
GPG Key ID: A50D5B4B599AF8A2
400 changed files with 0 additions and 21831 deletions

View File

@@ -1,61 +0,0 @@
/* ----------------------------------------------------------------------------
* This file was automatically generated by SWIG (http://www.swig.org).
* Version 4.0.2
*
* Do not make changes to this file unless you know what you are doing--modify
* the SWIG interface file instead.
* ----------------------------------------------------------------------------- */
package com.twitter.ann.faiss;
public class doubleArray {
private transient long swigCPtr;
protected transient boolean swigCMemOwn;
protected doubleArray(long cPtr, boolean cMemoryOwn) {
swigCMemOwn = cMemoryOwn;
swigCPtr = cPtr;
}
protected static long getCPtr(doubleArray obj) {
return (obj == null) ? 0 : obj.swigCPtr;
}
@SuppressWarnings("deprecation")
protected void finalize() {
delete();
}
public synchronized void delete() {
if (swigCPtr != 0) {
if (swigCMemOwn) {
swigCMemOwn = false;
swigfaissJNI.delete_doubleArray(swigCPtr);
}
swigCPtr = 0;
}
}
public doubleArray(int nelements) {
this(swigfaissJNI.new_doubleArray(nelements), true);
}
public double getitem(int index) {
return swigfaissJNI.doubleArray_getitem(swigCPtr, this, index);
}
public void setitem(int index, double value) {
swigfaissJNI.doubleArray_setitem(swigCPtr, this, index, value);
}
public SWIGTYPE_p_double cast() {
long cPtr = swigfaissJNI.doubleArray_cast(swigCPtr, this);
return (cPtr == 0) ? null : new SWIGTYPE_p_double(cPtr, false);
}
public static doubleArray frompointer(SWIGTYPE_p_double t) {
long cPtr = swigfaissJNI.doubleArray_frompointer(SWIGTYPE_p_double.getCPtr(t));
return (cPtr == 0) ? null : new doubleArray(cPtr, false);
}
}
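
For reference, these SWIG "carray" wrappers marshal Java primitives into native FAISS buffers. A minimal usage sketch (hypothetical caller code, not part of the deleted file; it assumes the swigfaiss JNI library is already loaded):

// Copy a Java double[] into a native buffer, read it back, then free it.
double[] src = {1.0, 2.0, 3.0};
doubleArray arr = new doubleArray(src.length); // cMemoryOwn = true: arr owns the buffer
for (int i = 0; i < src.length; i++) {
  arr.setitem(i, src[i]);
}
SWIGTYPE_p_double ptr = arr.cast(); // non-owning raw pointer view for native calls
double first = arr.getitem(0);      // 1.0
arr.delete(); // free the native buffer eagerly; finalize() is only a fallback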

View File

@@ -1,61 +0,0 @@
/* ----------------------------------------------------------------------------
* This file was automatically generated by SWIG (http://www.swig.org).
* Version 4.0.2
*
* Do not make changes to this file unless you know what you are doing--modify
* the SWIG interface file instead.
* ----------------------------------------------------------------------------- */
package com.twitter.ann.faiss;
public class floatArray {
private transient long swigCPtr;
protected transient boolean swigCMemOwn;
protected floatArray(long cPtr, boolean cMemoryOwn) {
swigCMemOwn = cMemoryOwn;
swigCPtr = cPtr;
}
protected static long getCPtr(floatArray obj) {
return (obj == null) ? 0 : obj.swigCPtr;
}
@SuppressWarnings("deprecation")
protected void finalize() {
delete();
}
public synchronized void delete() {
if (swigCPtr != 0) {
if (swigCMemOwn) {
swigCMemOwn = false;
swigfaissJNI.delete_floatArray(swigCPtr);
}
swigCPtr = 0;
}
}
public floatArray(int nelements) {
this(swigfaissJNI.new_floatArray(nelements), true);
}
public float getitem(int index) {
return swigfaissJNI.floatArray_getitem(swigCPtr, this, index);
}
public void setitem(int index, float value) {
swigfaissJNI.floatArray_setitem(swigCPtr, this, index, value);
}
public SWIGTYPE_p_float cast() {
long cPtr = swigfaissJNI.floatArray_cast(swigCPtr, this);
return (cPtr == 0) ? null : new SWIGTYPE_p_float(cPtr, false);
}
public static floatArray frompointer(SWIGTYPE_p_float t) {
long cPtr = swigfaissJNI.floatArray_frompointer(SWIGTYPE_p_float.getCPtr(t));
return (cPtr == 0) ? null : new floatArray(cPtr, false);
}
}

View File

@@ -1,133 +0,0 @@
/* ----------------------------------------------------------------------------
* This file was automatically generated by SWIG (http://www.swig.org).
* Version 4.0.2
*
* Do not make changes to this file unless you know what you are doing--modify
* the SWIG interface file instead.
* ----------------------------------------------------------------------------- */
package com.twitter.ann.faiss;
public class float_maxheap_array_t {
private transient long swigCPtr;
protected transient boolean swigCMemOwn;
protected float_maxheap_array_t(long cPtr, boolean cMemoryOwn) {
swigCMemOwn = cMemoryOwn;
swigCPtr = cPtr;
}
protected static long getCPtr(float_maxheap_array_t obj) {
return (obj == null) ? 0 : obj.swigCPtr;
}
@SuppressWarnings("deprecation")
protected void finalize() {
delete();
}
public synchronized void delete() {
if (swigCPtr != 0) {
if (swigCMemOwn) {
swigCMemOwn = false;
swigfaissJNI.delete_float_maxheap_array_t(swigCPtr);
}
swigCPtr = 0;
}
}
public void setNh(long value) {
swigfaissJNI.float_maxheap_array_t_nh_set(swigCPtr, this, value);
}
public long getNh() {
return swigfaissJNI.float_maxheap_array_t_nh_get(swigCPtr, this);
}
public void setK(long value) {
swigfaissJNI.float_maxheap_array_t_k_set(swigCPtr, this, value);
}
public long getK() {
return swigfaissJNI.float_maxheap_array_t_k_get(swigCPtr, this);
}
public void setIds(LongVector value) {
swigfaissJNI.float_maxheap_array_t_ids_set(swigCPtr, this, SWIGTYPE_p_long_long.getCPtr(value.data()), value);
}
public LongVector getIds() {
return new LongVector(swigfaissJNI.float_maxheap_array_t_ids_get(swigCPtr, this), false);
}
public void setVal(SWIGTYPE_p_float value) {
swigfaissJNI.float_maxheap_array_t_val_set(swigCPtr, this, SWIGTYPE_p_float.getCPtr(value));
}
public SWIGTYPE_p_float getVal() {
long cPtr = swigfaissJNI.float_maxheap_array_t_val_get(swigCPtr, this);
return (cPtr == 0) ? null : new SWIGTYPE_p_float(cPtr, false);
}
public SWIGTYPE_p_float get_val(long key) {
long cPtr = swigfaissJNI.float_maxheap_array_t_get_val(swigCPtr, this, key);
return (cPtr == 0) ? null : new SWIGTYPE_p_float(cPtr, false);
}
public LongVector get_ids(long key) {
return new LongVector(swigfaissJNI.float_maxheap_array_t_get_ids(swigCPtr, this, key), false);
}
public void heapify() {
swigfaissJNI.float_maxheap_array_t_heapify(swigCPtr, this);
}
public void addn(long nj, SWIGTYPE_p_float vin, long j0, long i0, long ni) {
swigfaissJNI.float_maxheap_array_t_addn__SWIG_0(swigCPtr, this, nj, SWIGTYPE_p_float.getCPtr(vin), j0, i0, ni);
}
public void addn(long nj, SWIGTYPE_p_float vin, long j0, long i0) {
swigfaissJNI.float_maxheap_array_t_addn__SWIG_1(swigCPtr, this, nj, SWIGTYPE_p_float.getCPtr(vin), j0, i0);
}
public void addn(long nj, SWIGTYPE_p_float vin, long j0) {
swigfaissJNI.float_maxheap_array_t_addn__SWIG_2(swigCPtr, this, nj, SWIGTYPE_p_float.getCPtr(vin), j0);
}
public void addn(long nj, SWIGTYPE_p_float vin) {
swigfaissJNI.float_maxheap_array_t_addn__SWIG_3(swigCPtr, this, nj, SWIGTYPE_p_float.getCPtr(vin));
}
public void addn_with_ids(long nj, SWIGTYPE_p_float vin, LongVector id_in, long id_stride, long i0, long ni) {
swigfaissJNI.float_maxheap_array_t_addn_with_ids__SWIG_0(swigCPtr, this, nj, SWIGTYPE_p_float.getCPtr(vin), SWIGTYPE_p_long_long.getCPtr(id_in.data()), id_in, id_stride, i0, ni);
}
public void addn_with_ids(long nj, SWIGTYPE_p_float vin, LongVector id_in, long id_stride, long i0) {
swigfaissJNI.float_maxheap_array_t_addn_with_ids__SWIG_1(swigCPtr, this, nj, SWIGTYPE_p_float.getCPtr(vin), SWIGTYPE_p_long_long.getCPtr(id_in.data()), id_in, id_stride, i0);
}
public void addn_with_ids(long nj, SWIGTYPE_p_float vin, LongVector id_in, long id_stride) {
swigfaissJNI.float_maxheap_array_t_addn_with_ids__SWIG_2(swigCPtr, this, nj, SWIGTYPE_p_float.getCPtr(vin), SWIGTYPE_p_long_long.getCPtr(id_in.data()), id_in, id_stride);
}
public void addn_with_ids(long nj, SWIGTYPE_p_float vin, LongVector id_in) {
swigfaissJNI.float_maxheap_array_t_addn_with_ids__SWIG_3(swigCPtr, this, nj, SWIGTYPE_p_float.getCPtr(vin), SWIGTYPE_p_long_long.getCPtr(id_in.data()), id_in);
}
public void addn_with_ids(long nj, SWIGTYPE_p_float vin) {
swigfaissJNI.float_maxheap_array_t_addn_with_ids__SWIG_4(swigCPtr, this, nj, SWIGTYPE_p_float.getCPtr(vin));
}
public void reorder() {
swigfaissJNI.float_maxheap_array_t_reorder(swigCPtr, this);
}
public void per_line_extrema(SWIGTYPE_p_float vals_out, LongVector idx_out) {
swigfaissJNI.float_maxheap_array_t_per_line_extrema(swigCPtr, this, SWIGTYPE_p_float.getCPtr(vals_out), SWIGTYPE_p_long_long.getCPtr(idx_out.data()), idx_out);
}
public float_maxheap_array_t() {
this(swigfaissJNI.new_float_maxheap_array_t(), true);
}
}
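
This class corresponds to a faiss HeapArray over a CMax comparator: nh heaps of capacity k backed by flat val/ids buffers, used to collect the k smallest distances per query. A hedged sketch of the heapify/addn/reorder cycle (hypothetical caller code; id storage via setIds(LongVector) is elided because LongVector is defined elsewhere in this package):

// One max-heap (nh = 1) that keeps the 8 smallest candidate distances.
float_maxheap_array_t ha = new float_maxheap_array_t();
ha.setNh(1);                        // number of heaps, typically one per query
ha.setK(8);                         // capacity of each heap
floatArray val = new floatArray(8); // native backing storage for heap values
ha.setVal(val.cast());
ha.heapify();                       // seed each heap with sentinel (worst) values
// ... ha.addn(...) / ha.addn_with_ids(...) push candidate batches here ...
ha.reorder();                       // sort each heap's contents in ascending distance
float best = val.getitem(0);
ha.delete();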

View File

@@ -1,133 +0,0 @@
/* ----------------------------------------------------------------------------
* This file was automatically generated by SWIG (http://www.swig.org).
* Version 4.0.2
*
* Do not make changes to this file unless you know what you are doing--modify
* the SWIG interface file instead.
* ----------------------------------------------------------------------------- */
package com.twitter.ann.faiss;
public class float_minheap_array_t {
private transient long swigCPtr;
protected transient boolean swigCMemOwn;
protected float_minheap_array_t(long cPtr, boolean cMemoryOwn) {
swigCMemOwn = cMemoryOwn;
swigCPtr = cPtr;
}
protected static long getCPtr(float_minheap_array_t obj) {
return (obj == null) ? 0 : obj.swigCPtr;
}
@SuppressWarnings("deprecation")
protected void finalize() {
delete();
}
public synchronized void delete() {
if (swigCPtr != 0) {
if (swigCMemOwn) {
swigCMemOwn = false;
swigfaissJNI.delete_float_minheap_array_t(swigCPtr);
}
swigCPtr = 0;
}
}
public void setNh(long value) {
swigfaissJNI.float_minheap_array_t_nh_set(swigCPtr, this, value);
}
public long getNh() {
return swigfaissJNI.float_minheap_array_t_nh_get(swigCPtr, this);
}
public void setK(long value) {
swigfaissJNI.float_minheap_array_t_k_set(swigCPtr, this, value);
}
public long getK() {
return swigfaissJNI.float_minheap_array_t_k_get(swigCPtr, this);
}
public void setIds(LongVector value) {
swigfaissJNI.float_minheap_array_t_ids_set(swigCPtr, this, SWIGTYPE_p_long_long.getCPtr(value.data()), value);
}
public LongVector getIds() {
return new LongVector(swigfaissJNI.float_minheap_array_t_ids_get(swigCPtr, this), false);
}
public void setVal(SWIGTYPE_p_float value) {
swigfaissJNI.float_minheap_array_t_val_set(swigCPtr, this, SWIGTYPE_p_float.getCPtr(value));
}
public SWIGTYPE_p_float getVal() {
long cPtr = swigfaissJNI.float_minheap_array_t_val_get(swigCPtr, this);
return (cPtr == 0) ? null : new SWIGTYPE_p_float(cPtr, false);
}
public SWIGTYPE_p_float get_val(long key) {
long cPtr = swigfaissJNI.float_minheap_array_t_get_val(swigCPtr, this, key);
return (cPtr == 0) ? null : new SWIGTYPE_p_float(cPtr, false);
}
public LongVector get_ids(long key) {
return new LongVector(swigfaissJNI.float_minheap_array_t_get_ids(swigCPtr, this, key), false);
}
public void heapify() {
swigfaissJNI.float_minheap_array_t_heapify(swigCPtr, this);
}
public void addn(long nj, SWIGTYPE_p_float vin, long j0, long i0, long ni) {
swigfaissJNI.float_minheap_array_t_addn__SWIG_0(swigCPtr, this, nj, SWIGTYPE_p_float.getCPtr(vin), j0, i0, ni);
}
public void addn(long nj, SWIGTYPE_p_float vin, long j0, long i0) {
swigfaissJNI.float_minheap_array_t_addn__SWIG_1(swigCPtr, this, nj, SWIGTYPE_p_float.getCPtr(vin), j0, i0);
}
public void addn(long nj, SWIGTYPE_p_float vin, long j0) {
swigfaissJNI.float_minheap_array_t_addn__SWIG_2(swigCPtr, this, nj, SWIGTYPE_p_float.getCPtr(vin), j0);
}
public void addn(long nj, SWIGTYPE_p_float vin) {
swigfaissJNI.float_minheap_array_t_addn__SWIG_3(swigCPtr, this, nj, SWIGTYPE_p_float.getCPtr(vin));
}
public void addn_with_ids(long nj, SWIGTYPE_p_float vin, LongVector id_in, long id_stride, long i0, long ni) {
swigfaissJNI.float_minheap_array_t_addn_with_ids__SWIG_0(swigCPtr, this, nj, SWIGTYPE_p_float.getCPtr(vin), SWIGTYPE_p_long_long.getCPtr(id_in.data()), id_in, id_stride, i0, ni);
}
public void addn_with_ids(long nj, SWIGTYPE_p_float vin, LongVector id_in, long id_stride, long i0) {
swigfaissJNI.float_minheap_array_t_addn_with_ids__SWIG_1(swigCPtr, this, nj, SWIGTYPE_p_float.getCPtr(vin), SWIGTYPE_p_long_long.getCPtr(id_in.data()), id_in, id_stride, i0);
}
public void addn_with_ids(long nj, SWIGTYPE_p_float vin, LongVector id_in, long id_stride) {
swigfaissJNI.float_minheap_array_t_addn_with_ids__SWIG_2(swigCPtr, this, nj, SWIGTYPE_p_float.getCPtr(vin), SWIGTYPE_p_long_long.getCPtr(id_in.data()), id_in, id_stride);
}
public void addn_with_ids(long nj, SWIGTYPE_p_float vin, LongVector id_in) {
swigfaissJNI.float_minheap_array_t_addn_with_ids__SWIG_3(swigCPtr, this, nj, SWIGTYPE_p_float.getCPtr(vin), SWIGTYPE_p_long_long.getCPtr(id_in.data()), id_in);
}
public void addn_with_ids(long nj, SWIGTYPE_p_float vin) {
swigfaissJNI.float_minheap_array_t_addn_with_ids__SWIG_4(swigCPtr, this, nj, SWIGTYPE_p_float.getCPtr(vin));
}
public void reorder() {
swigfaissJNI.float_minheap_array_t_reorder(swigCPtr, this);
}
public void per_line_extrema(SWIGTYPE_p_float vals_out, LongVector idx_out) {
swigfaissJNI.float_minheap_array_t_per_line_extrema(swigCPtr, this, SWIGTYPE_p_float.getCPtr(vals_out), SWIGTYPE_p_long_long.getCPtr(idx_out.data()), idx_out);
}
public float_minheap_array_t() {
this(swigfaissJNI.new_float_minheap_array_t(), true);
}
}

View File

@@ -1,61 +0,0 @@
/* ----------------------------------------------------------------------------
* This file was automatically generated by SWIG (http://www.swig.org).
* Version 4.0.2
*
* Do not make changes to this file unless you know what you are doing--modify
* the SWIG interface file instead.
* ----------------------------------------------------------------------------- */
package com.twitter.ann.faiss;
public class intArray {
private transient long swigCPtr;
protected transient boolean swigCMemOwn;
protected intArray(long cPtr, boolean cMemoryOwn) {
swigCMemOwn = cMemoryOwn;
swigCPtr = cPtr;
}
protected static long getCPtr(intArray obj) {
return (obj == null) ? 0 : obj.swigCPtr;
}
@SuppressWarnings("deprecation")
protected void finalize() {
delete();
}
public synchronized void delete() {
if (swigCPtr != 0) {
if (swigCMemOwn) {
swigCMemOwn = false;
swigfaissJNI.delete_intArray(swigCPtr);
}
swigCPtr = 0;
}
}
public intArray(int nelements) {
this(swigfaissJNI.new_intArray(nelements), true);
}
public int getitem(int index) {
return swigfaissJNI.intArray_getitem(swigCPtr, this, index);
}
public void setitem(int index, int value) {
swigfaissJNI.intArray_setitem(swigCPtr, this, index, value);
}
public SWIGTYPE_p_int cast() {
long cPtr = swigfaissJNI.intArray_cast(swigCPtr, this);
return (cPtr == 0) ? null : new SWIGTYPE_p_int(cPtr, false);
}
public static intArray frompointer(SWIGTYPE_p_int t) {
long cPtr = swigfaissJNI.intArray_frompointer(SWIGTYPE_p_int.getCPtr(t));
return (cPtr == 0) ? null : new intArray(cPtr, false);
}
}

View File

@@ -1,133 +0,0 @@
/* ----------------------------------------------------------------------------
* This file was automatically generated by SWIG (http://www.swig.org).
* Version 4.0.2
*
* Do not make changes to this file unless you know what you are doing--modify
* the SWIG interface file instead.
* ----------------------------------------------------------------------------- */
package com.twitter.ann.faiss;
public class int_maxheap_array_t {
private transient long swigCPtr;
protected transient boolean swigCMemOwn;
protected int_maxheap_array_t(long cPtr, boolean cMemoryOwn) {
swigCMemOwn = cMemoryOwn;
swigCPtr = cPtr;
}
protected static long getCPtr(int_maxheap_array_t obj) {
return (obj == null) ? 0 : obj.swigCPtr;
}
@SuppressWarnings("deprecation")
protected void finalize() {
delete();
}
public synchronized void delete() {
if (swigCPtr != 0) {
if (swigCMemOwn) {
swigCMemOwn = false;
swigfaissJNI.delete_int_maxheap_array_t(swigCPtr);
}
swigCPtr = 0;
}
}
public void setNh(long value) {
swigfaissJNI.int_maxheap_array_t_nh_set(swigCPtr, this, value);
}
public long getNh() {
return swigfaissJNI.int_maxheap_array_t_nh_get(swigCPtr, this);
}
public void setK(long value) {
swigfaissJNI.int_maxheap_array_t_k_set(swigCPtr, this, value);
}
public long getK() {
return swigfaissJNI.int_maxheap_array_t_k_get(swigCPtr, this);
}
public void setIds(LongVector value) {
swigfaissJNI.int_maxheap_array_t_ids_set(swigCPtr, this, SWIGTYPE_p_long_long.getCPtr(value.data()), value);
}
public LongVector getIds() {
return new LongVector(swigfaissJNI.int_maxheap_array_t_ids_get(swigCPtr, this), false);
}
public void setVal(SWIGTYPE_p_int value) {
swigfaissJNI.int_maxheap_array_t_val_set(swigCPtr, this, SWIGTYPE_p_int.getCPtr(value));
}
public SWIGTYPE_p_int getVal() {
long cPtr = swigfaissJNI.int_maxheap_array_t_val_get(swigCPtr, this);
return (cPtr == 0) ? null : new SWIGTYPE_p_int(cPtr, false);
}
public SWIGTYPE_p_int get_val(long key) {
long cPtr = swigfaissJNI.int_maxheap_array_t_get_val(swigCPtr, this, key);
return (cPtr == 0) ? null : new SWIGTYPE_p_int(cPtr, false);
}
public LongVector get_ids(long key) {
return new LongVector(swigfaissJNI.int_maxheap_array_t_get_ids(swigCPtr, this, key), false);
}
public void heapify() {
swigfaissJNI.int_maxheap_array_t_heapify(swigCPtr, this);
}
public void addn(long nj, SWIGTYPE_p_int vin, long j0, long i0, long ni) {
swigfaissJNI.int_maxheap_array_t_addn__SWIG_0(swigCPtr, this, nj, SWIGTYPE_p_int.getCPtr(vin), j0, i0, ni);
}
public void addn(long nj, SWIGTYPE_p_int vin, long j0, long i0) {
swigfaissJNI.int_maxheap_array_t_addn__SWIG_1(swigCPtr, this, nj, SWIGTYPE_p_int.getCPtr(vin), j0, i0);
}
public void addn(long nj, SWIGTYPE_p_int vin, long j0) {
swigfaissJNI.int_maxheap_array_t_addn__SWIG_2(swigCPtr, this, nj, SWIGTYPE_p_int.getCPtr(vin), j0);
}
public void addn(long nj, SWIGTYPE_p_int vin) {
swigfaissJNI.int_maxheap_array_t_addn__SWIG_3(swigCPtr, this, nj, SWIGTYPE_p_int.getCPtr(vin));
}
public void addn_with_ids(long nj, SWIGTYPE_p_int vin, LongVector id_in, long id_stride, long i0, long ni) {
swigfaissJNI.int_maxheap_array_t_addn_with_ids__SWIG_0(swigCPtr, this, nj, SWIGTYPE_p_int.getCPtr(vin), SWIGTYPE_p_long_long.getCPtr(id_in.data()), id_in, id_stride, i0, ni);
}
public void addn_with_ids(long nj, SWIGTYPE_p_int vin, LongVector id_in, long id_stride, long i0) {
swigfaissJNI.int_maxheap_array_t_addn_with_ids__SWIG_1(swigCPtr, this, nj, SWIGTYPE_p_int.getCPtr(vin), SWIGTYPE_p_long_long.getCPtr(id_in.data()), id_in, id_stride, i0);
}
public void addn_with_ids(long nj, SWIGTYPE_p_int vin, LongVector id_in, long id_stride) {
swigfaissJNI.int_maxheap_array_t_addn_with_ids__SWIG_2(swigCPtr, this, nj, SWIGTYPE_p_int.getCPtr(vin), SWIGTYPE_p_long_long.getCPtr(id_in.data()), id_in, id_stride);
}
public void addn_with_ids(long nj, SWIGTYPE_p_int vin, LongVector id_in) {
swigfaissJNI.int_maxheap_array_t_addn_with_ids__SWIG_3(swigCPtr, this, nj, SWIGTYPE_p_int.getCPtr(vin), SWIGTYPE_p_long_long.getCPtr(id_in.data()), id_in);
}
public void addn_with_ids(long nj, SWIGTYPE_p_int vin) {
swigfaissJNI.int_maxheap_array_t_addn_with_ids__SWIG_4(swigCPtr, this, nj, SWIGTYPE_p_int.getCPtr(vin));
}
public void reorder() {
swigfaissJNI.int_maxheap_array_t_reorder(swigCPtr, this);
}
public void per_line_extrema(SWIGTYPE_p_int vals_out, LongVector idx_out) {
swigfaissJNI.int_maxheap_array_t_per_line_extrema(swigCPtr, this, SWIGTYPE_p_int.getCPtr(vals_out), SWIGTYPE_p_long_long.getCPtr(idx_out.data()), idx_out);
}
public int_maxheap_array_t() {
this(swigfaissJNI.new_int_maxheap_array_t(), true);
}
}

View File

@@ -1,133 +0,0 @@
/* ----------------------------------------------------------------------------
* This file was automatically generated by SWIG (http://www.swig.org).
* Version 4.0.2
*
* Do not make changes to this file unless you know what you are doing--modify
* the SWIG interface file instead.
* ----------------------------------------------------------------------------- */
package com.twitter.ann.faiss;
public class int_minheap_array_t {
private transient long swigCPtr;
protected transient boolean swigCMemOwn;
protected int_minheap_array_t(long cPtr, boolean cMemoryOwn) {
swigCMemOwn = cMemoryOwn;
swigCPtr = cPtr;
}
protected static long getCPtr(int_minheap_array_t obj) {
return (obj == null) ? 0 : obj.swigCPtr;
}
@SuppressWarnings("deprecation")
protected void finalize() {
delete();
}
public synchronized void delete() {
if (swigCPtr != 0) {
if (swigCMemOwn) {
swigCMemOwn = false;
swigfaissJNI.delete_int_minheap_array_t(swigCPtr);
}
swigCPtr = 0;
}
}
public void setNh(long value) {
swigfaissJNI.int_minheap_array_t_nh_set(swigCPtr, this, value);
}
public long getNh() {
return swigfaissJNI.int_minheap_array_t_nh_get(swigCPtr, this);
}
public void setK(long value) {
swigfaissJNI.int_minheap_array_t_k_set(swigCPtr, this, value);
}
public long getK() {
return swigfaissJNI.int_minheap_array_t_k_get(swigCPtr, this);
}
public void setIds(LongVector value) {
swigfaissJNI.int_minheap_array_t_ids_set(swigCPtr, this, SWIGTYPE_p_long_long.getCPtr(value.data()), value);
}
public LongVector getIds() {
return new LongVector(swigfaissJNI.int_minheap_array_t_ids_get(swigCPtr, this), false);
}
public void setVal(SWIGTYPE_p_int value) {
swigfaissJNI.int_minheap_array_t_val_set(swigCPtr, this, SWIGTYPE_p_int.getCPtr(value));
}
public SWIGTYPE_p_int getVal() {
long cPtr = swigfaissJNI.int_minheap_array_t_val_get(swigCPtr, this);
return (cPtr == 0) ? null : new SWIGTYPE_p_int(cPtr, false);
}
public SWIGTYPE_p_int get_val(long key) {
long cPtr = swigfaissJNI.int_minheap_array_t_get_val(swigCPtr, this, key);
return (cPtr == 0) ? null : new SWIGTYPE_p_int(cPtr, false);
}
public LongVector get_ids(long key) {
return new LongVector(swigfaissJNI.int_minheap_array_t_get_ids(swigCPtr, this, key), false);
}
public void heapify() {
swigfaissJNI.int_minheap_array_t_heapify(swigCPtr, this);
}
public void addn(long nj, SWIGTYPE_p_int vin, long j0, long i0, long ni) {
swigfaissJNI.int_minheap_array_t_addn__SWIG_0(swigCPtr, this, nj, SWIGTYPE_p_int.getCPtr(vin), j0, i0, ni);
}
public void addn(long nj, SWIGTYPE_p_int vin, long j0, long i0) {
swigfaissJNI.int_minheap_array_t_addn__SWIG_1(swigCPtr, this, nj, SWIGTYPE_p_int.getCPtr(vin), j0, i0);
}
public void addn(long nj, SWIGTYPE_p_int vin, long j0) {
swigfaissJNI.int_minheap_array_t_addn__SWIG_2(swigCPtr, this, nj, SWIGTYPE_p_int.getCPtr(vin), j0);
}
public void addn(long nj, SWIGTYPE_p_int vin) {
swigfaissJNI.int_minheap_array_t_addn__SWIG_3(swigCPtr, this, nj, SWIGTYPE_p_int.getCPtr(vin));
}
public void addn_with_ids(long nj, SWIGTYPE_p_int vin, LongVector id_in, long id_stride, long i0, long ni) {
swigfaissJNI.int_minheap_array_t_addn_with_ids__SWIG_0(swigCPtr, this, nj, SWIGTYPE_p_int.getCPtr(vin), SWIGTYPE_p_long_long.getCPtr(id_in.data()), id_in, id_stride, i0, ni);
}
public void addn_with_ids(long nj, SWIGTYPE_p_int vin, LongVector id_in, long id_stride, long i0) {
swigfaissJNI.int_minheap_array_t_addn_with_ids__SWIG_1(swigCPtr, this, nj, SWIGTYPE_p_int.getCPtr(vin), SWIGTYPE_p_long_long.getCPtr(id_in.data()), id_in, id_stride, i0);
}
public void addn_with_ids(long nj, SWIGTYPE_p_int vin, LongVector id_in, long id_stride) {
swigfaissJNI.int_minheap_array_t_addn_with_ids__SWIG_2(swigCPtr, this, nj, SWIGTYPE_p_int.getCPtr(vin), SWIGTYPE_p_long_long.getCPtr(id_in.data()), id_in, id_stride);
}
public void addn_with_ids(long nj, SWIGTYPE_p_int vin, LongVector id_in) {
swigfaissJNI.int_minheap_array_t_addn_with_ids__SWIG_3(swigCPtr, this, nj, SWIGTYPE_p_int.getCPtr(vin), SWIGTYPE_p_long_long.getCPtr(id_in.data()), id_in);
}
public void addn_with_ids(long nj, SWIGTYPE_p_int vin) {
swigfaissJNI.int_minheap_array_t_addn_with_ids__SWIG_4(swigCPtr, this, nj, SWIGTYPE_p_int.getCPtr(vin));
}
public void reorder() {
swigfaissJNI.int_minheap_array_t_reorder(swigCPtr, this);
}
public void per_line_extrema(SWIGTYPE_p_int vals_out, LongVector idx_out) {
swigfaissJNI.int_minheap_array_t_per_line_extrema(swigCPtr, this, SWIGTYPE_p_int.getCPtr(vals_out), SWIGTYPE_p_long_long.getCPtr(idx_out.data()), idx_out);
}
public int_minheap_array_t() {
this(swigfaissJNI.new_int_minheap_array_t(), true);
}
}

View File

@@ -1,61 +0,0 @@
/* ----------------------------------------------------------------------------
* This file was automatically generated by SWIG (http://www.swig.org).
* Version 4.0.2
*
* Do not make changes to this file unless you know what you are doing--modify
* the SWIG interface file instead.
* ----------------------------------------------------------------------------- */
package com.twitter.ann.faiss;
public class longArray {
private transient long swigCPtr;
protected transient boolean swigCMemOwn;
protected longArray(long cPtr, boolean cMemoryOwn) {
swigCMemOwn = cMemoryOwn;
swigCPtr = cPtr;
}
protected static long getCPtr(longArray obj) {
return (obj == null) ? 0 : obj.swigCPtr;
}
@SuppressWarnings("deprecation")
protected void finalize() {
delete();
}
public synchronized void delete() {
if (swigCPtr != 0) {
if (swigCMemOwn) {
swigCMemOwn = false;
swigfaissJNI.delete_longArray(swigCPtr);
}
swigCPtr = 0;
}
}
public longArray(int nelements) {
this(swigfaissJNI.new_longArray(nelements), true);
}
public long getitem(int index) {
return swigfaissJNI.longArray_getitem(swigCPtr, this, index);
}
public void setitem(int index, long value) {
swigfaissJNI.longArray_setitem(swigCPtr, this, index, value);
}
public SWIGTYPE_p_long_long cast() {
long cPtr = swigfaissJNI.longArray_cast(swigCPtr, this);
return (cPtr == 0) ? null : new SWIGTYPE_p_long_long(cPtr, false);
}
public static longArray frompointer(SWIGTYPE_p_long_long t) {
long cPtr = swigfaissJNI.longArray_frompointer(SWIGTYPE_p_long_long.getCPtr(t));
return (cPtr == 0) ? null : new longArray(cPtr, false);
}
}

View File

@@ -1,575 +0,0 @@
/* ----------------------------------------------------------------------------
* This file was automatically generated by SWIG (http://www.swig.org).
* Version 4.0.2
*
* Do not make changes to this file unless you know what you are doing--modify
* the SWIG interface file instead.
* ----------------------------------------------------------------------------- */
package com.twitter.ann.faiss;
public class swigfaiss implements swigfaissConstants {
public static void bitvec_print(SWIGTYPE_p_unsigned_char b, long d) {
swigfaissJNI.bitvec_print(SWIGTYPE_p_unsigned_char.getCPtr(b), d);
}
public static void fvecs2bitvecs(SWIGTYPE_p_float x, SWIGTYPE_p_unsigned_char b, long d, long n) {
swigfaissJNI.fvecs2bitvecs(SWIGTYPE_p_float.getCPtr(x), SWIGTYPE_p_unsigned_char.getCPtr(b), d, n);
}
public static void bitvecs2fvecs(SWIGTYPE_p_unsigned_char b, SWIGTYPE_p_float x, long d, long n) {
swigfaissJNI.bitvecs2fvecs(SWIGTYPE_p_unsigned_char.getCPtr(b), SWIGTYPE_p_float.getCPtr(x), d, n);
}
public static void fvec2bitvec(SWIGTYPE_p_float x, SWIGTYPE_p_unsigned_char b, long d) {
swigfaissJNI.fvec2bitvec(SWIGTYPE_p_float.getCPtr(x), SWIGTYPE_p_unsigned_char.getCPtr(b), d);
}
public static void bitvec_shuffle(long n, long da, long db, SWIGTYPE_p_int order, SWIGTYPE_p_unsigned_char a, SWIGTYPE_p_unsigned_char b) {
swigfaissJNI.bitvec_shuffle(n, da, db, SWIGTYPE_p_int.getCPtr(order), SWIGTYPE_p_unsigned_char.getCPtr(a), SWIGTYPE_p_unsigned_char.getCPtr(b));
}
public static void setHamming_batch_size(long value) {
swigfaissJNI.hamming_batch_size_set(value);
}
public static long getHamming_batch_size() {
return swigfaissJNI.hamming_batch_size_get();
}
public static int popcount64(long x) {
return swigfaissJNI.popcount64(x);
}
public static void hammings(SWIGTYPE_p_unsigned_char a, SWIGTYPE_p_unsigned_char b, long na, long nb, long nbytespercode, SWIGTYPE_p_int dis) {
swigfaissJNI.hammings(SWIGTYPE_p_unsigned_char.getCPtr(a), SWIGTYPE_p_unsigned_char.getCPtr(b), na, nb, nbytespercode, SWIGTYPE_p_int.getCPtr(dis));
}
public static void hammings_knn_hc(SWIGTYPE_p_faiss__HeapArrayT_faiss__CMaxT_int_int64_t_t_t ha, SWIGTYPE_p_unsigned_char a, SWIGTYPE_p_unsigned_char b, long nb, long ncodes, int ordered) {
swigfaissJNI.hammings_knn_hc(SWIGTYPE_p_faiss__HeapArrayT_faiss__CMaxT_int_int64_t_t_t.getCPtr(ha), SWIGTYPE_p_unsigned_char.getCPtr(a), SWIGTYPE_p_unsigned_char.getCPtr(b), nb, ncodes, ordered);
}
public static void hammings_knn(SWIGTYPE_p_faiss__HeapArrayT_faiss__CMaxT_int_int64_t_t_t ha, SWIGTYPE_p_unsigned_char a, SWIGTYPE_p_unsigned_char b, long nb, long ncodes, int ordered) {
swigfaissJNI.hammings_knn(SWIGTYPE_p_faiss__HeapArrayT_faiss__CMaxT_int_int64_t_t_t.getCPtr(ha), SWIGTYPE_p_unsigned_char.getCPtr(a), SWIGTYPE_p_unsigned_char.getCPtr(b), nb, ncodes, ordered);
}
public static void hammings_knn_mc(SWIGTYPE_p_unsigned_char a, SWIGTYPE_p_unsigned_char b, long na, long nb, long k, long ncodes, SWIGTYPE_p_int distances, LongVector labels) {
swigfaissJNI.hammings_knn_mc(SWIGTYPE_p_unsigned_char.getCPtr(a), SWIGTYPE_p_unsigned_char.getCPtr(b), na, nb, k, ncodes, SWIGTYPE_p_int.getCPtr(distances), SWIGTYPE_p_long_long.getCPtr(labels.data()), labels);
}
public static void hamming_range_search(SWIGTYPE_p_unsigned_char a, SWIGTYPE_p_unsigned_char b, long na, long nb, int radius, long ncodes, RangeSearchResult result) {
swigfaissJNI.hamming_range_search(SWIGTYPE_p_unsigned_char.getCPtr(a), SWIGTYPE_p_unsigned_char.getCPtr(b), na, nb, radius, ncodes, RangeSearchResult.getCPtr(result), result);
}
public static void hamming_count_thres(SWIGTYPE_p_unsigned_char bs1, SWIGTYPE_p_unsigned_char bs2, long n1, long n2, int ht, long ncodes, SWIGTYPE_p_unsigned_long nptr) {
swigfaissJNI.hamming_count_thres(SWIGTYPE_p_unsigned_char.getCPtr(bs1), SWIGTYPE_p_unsigned_char.getCPtr(bs2), n1, n2, ht, ncodes, SWIGTYPE_p_unsigned_long.getCPtr(nptr));
}
public static long match_hamming_thres(SWIGTYPE_p_unsigned_char bs1, SWIGTYPE_p_unsigned_char bs2, long n1, long n2, int ht, long ncodes, LongVector idx, SWIGTYPE_p_int dis) {
return swigfaissJNI.match_hamming_thres(SWIGTYPE_p_unsigned_char.getCPtr(bs1), SWIGTYPE_p_unsigned_char.getCPtr(bs2), n1, n2, ht, ncodes, SWIGTYPE_p_long_long.getCPtr(idx.data()), idx, SWIGTYPE_p_int.getCPtr(dis));
}
public static void crosshamming_count_thres(SWIGTYPE_p_unsigned_char dbs, long n, int ht, long ncodes, SWIGTYPE_p_unsigned_long nptr) {
swigfaissJNI.crosshamming_count_thres(SWIGTYPE_p_unsigned_char.getCPtr(dbs), n, ht, ncodes, SWIGTYPE_p_unsigned_long.getCPtr(nptr));
}
public static int get_num_gpus() {
return swigfaissJNI.get_num_gpus();
}
public static String get_compile_options() {
return swigfaissJNI.get_compile_options();
}
public static double getmillisecs() {
return swigfaissJNI.getmillisecs();
}
public static long get_mem_usage_kb() {
return swigfaissJNI.get_mem_usage_kb();
}
public static long get_cycles() {
return swigfaissJNI.get_cycles();
}
public static void fvec_madd(long n, SWIGTYPE_p_float a, float bf, SWIGTYPE_p_float b, SWIGTYPE_p_float c) {
swigfaissJNI.fvec_madd(n, SWIGTYPE_p_float.getCPtr(a), bf, SWIGTYPE_p_float.getCPtr(b), SWIGTYPE_p_float.getCPtr(c));
}
public static int fvec_madd_and_argmin(long n, SWIGTYPE_p_float a, float bf, SWIGTYPE_p_float b, SWIGTYPE_p_float c) {
return swigfaissJNI.fvec_madd_and_argmin(n, SWIGTYPE_p_float.getCPtr(a), bf, SWIGTYPE_p_float.getCPtr(b), SWIGTYPE_p_float.getCPtr(c));
}
public static void reflection(SWIGTYPE_p_float u, SWIGTYPE_p_float x, long n, long d, long nu) {
swigfaissJNI.reflection(SWIGTYPE_p_float.getCPtr(u), SWIGTYPE_p_float.getCPtr(x), n, d, nu);
}
public static void matrix_qr(int m, int n, SWIGTYPE_p_float a) {
swigfaissJNI.matrix_qr(m, n, SWIGTYPE_p_float.getCPtr(a));
}
public static void ranklist_handle_ties(int k, LongVector idx, SWIGTYPE_p_float dis) {
swigfaissJNI.ranklist_handle_ties(k, SWIGTYPE_p_long_long.getCPtr(idx.data()), idx, SWIGTYPE_p_float.getCPtr(dis));
}
public static long ranklist_intersection_size(long k1, LongVector v1, long k2, LongVector v2) {
return swigfaissJNI.ranklist_intersection_size(k1, SWIGTYPE_p_long_long.getCPtr(v1.data()), v1, k2, SWIGTYPE_p_long_long.getCPtr(v2.data()), v2);
}
public static long merge_result_table_with(long n, long k, LongVector I0, SWIGTYPE_p_float D0, LongVector I1, SWIGTYPE_p_float D1, boolean keep_min, long translation) {
return swigfaissJNI.merge_result_table_with__SWIG_0(n, k, SWIGTYPE_p_long_long.getCPtr(I0.data()), I0, SWIGTYPE_p_float.getCPtr(D0), SWIGTYPE_p_long_long.getCPtr(I1.data()), I1, SWIGTYPE_p_float.getCPtr(D1), keep_min, translation);
}
public static long merge_result_table_with(long n, long k, LongVector I0, SWIGTYPE_p_float D0, LongVector I1, SWIGTYPE_p_float D1, boolean keep_min) {
return swigfaissJNI.merge_result_table_with__SWIG_1(n, k, SWIGTYPE_p_long_long.getCPtr(I0.data()), I0, SWIGTYPE_p_float.getCPtr(D0), SWIGTYPE_p_long_long.getCPtr(I1.data()), I1, SWIGTYPE_p_float.getCPtr(D1), keep_min);
}
public static long merge_result_table_with(long n, long k, LongVector I0, SWIGTYPE_p_float D0, LongVector I1, SWIGTYPE_p_float D1) {
return swigfaissJNI.merge_result_table_with__SWIG_2(n, k, SWIGTYPE_p_long_long.getCPtr(I0.data()), I0, SWIGTYPE_p_float.getCPtr(D0), SWIGTYPE_p_long_long.getCPtr(I1.data()), I1, SWIGTYPE_p_float.getCPtr(D1));
}
public static double imbalance_factor(int n, int k, LongVector assign) {
return swigfaissJNI.imbalance_factor__SWIG_0(n, k, SWIGTYPE_p_long_long.getCPtr(assign.data()), assign);
}
public static double imbalance_factor(int k, SWIGTYPE_p_int hist) {
return swigfaissJNI.imbalance_factor__SWIG_1(k, SWIGTYPE_p_int.getCPtr(hist));
}
public static void fvec_argsort(long n, SWIGTYPE_p_float vals, SWIGTYPE_p_unsigned_long perm) {
swigfaissJNI.fvec_argsort(n, SWIGTYPE_p_float.getCPtr(vals), SWIGTYPE_p_unsigned_long.getCPtr(perm));
}
public static void fvec_argsort_parallel(long n, SWIGTYPE_p_float vals, SWIGTYPE_p_unsigned_long perm) {
swigfaissJNI.fvec_argsort_parallel(n, SWIGTYPE_p_float.getCPtr(vals), SWIGTYPE_p_unsigned_long.getCPtr(perm));
}
public static int ivec_hist(long n, SWIGTYPE_p_int v, int vmax, SWIGTYPE_p_int hist) {
return swigfaissJNI.ivec_hist(n, SWIGTYPE_p_int.getCPtr(v), vmax, SWIGTYPE_p_int.getCPtr(hist));
}
public static void bincode_hist(long n, long nbits, SWIGTYPE_p_unsigned_char codes, SWIGTYPE_p_int hist) {
swigfaissJNI.bincode_hist(n, nbits, SWIGTYPE_p_unsigned_char.getCPtr(codes), SWIGTYPE_p_int.getCPtr(hist));
}
public static long ivec_checksum(long n, SWIGTYPE_p_int a) {
return swigfaissJNI.ivec_checksum(n, SWIGTYPE_p_int.getCPtr(a));
}
public static SWIGTYPE_p_float fvecs_maybe_subsample(long d, SWIGTYPE_p_unsigned_long n, long nmax, SWIGTYPE_p_float x, boolean verbose, long seed) {
long cPtr = swigfaissJNI.fvecs_maybe_subsample__SWIG_0(d, SWIGTYPE_p_unsigned_long.getCPtr(n), nmax, SWIGTYPE_p_float.getCPtr(x), verbose, seed);
return (cPtr == 0) ? null : new SWIGTYPE_p_float(cPtr, false);
}
public static SWIGTYPE_p_float fvecs_maybe_subsample(long d, SWIGTYPE_p_unsigned_long n, long nmax, SWIGTYPE_p_float x, boolean verbose) {
long cPtr = swigfaissJNI.fvecs_maybe_subsample__SWIG_1(d, SWIGTYPE_p_unsigned_long.getCPtr(n), nmax, SWIGTYPE_p_float.getCPtr(x), verbose);
return (cPtr == 0) ? null : new SWIGTYPE_p_float(cPtr, false);
}
public static SWIGTYPE_p_float fvecs_maybe_subsample(long d, SWIGTYPE_p_unsigned_long n, long nmax, SWIGTYPE_p_float x) {
long cPtr = swigfaissJNI.fvecs_maybe_subsample__SWIG_2(d, SWIGTYPE_p_unsigned_long.getCPtr(n), nmax, SWIGTYPE_p_float.getCPtr(x));
return (cPtr == 0) ? null : new SWIGTYPE_p_float(cPtr, false);
}
public static void binary_to_real(long d, SWIGTYPE_p_unsigned_char x_in, SWIGTYPE_p_float x_out) {
swigfaissJNI.binary_to_real(d, SWIGTYPE_p_unsigned_char.getCPtr(x_in), SWIGTYPE_p_float.getCPtr(x_out));
}
public static void real_to_binary(long d, SWIGTYPE_p_float x_in, SWIGTYPE_p_unsigned_char x_out) {
swigfaissJNI.real_to_binary(d, SWIGTYPE_p_float.getCPtr(x_in), SWIGTYPE_p_unsigned_char.getCPtr(x_out));
}
public static long hash_bytes(SWIGTYPE_p_unsigned_char bytes, long n) {
return swigfaissJNI.hash_bytes(SWIGTYPE_p_unsigned_char.getCPtr(bytes), n);
}
public static boolean check_openmp() {
return swigfaissJNI.check_openmp();
}
public static float kmeans_clustering(long d, long n, long k, SWIGTYPE_p_float x, SWIGTYPE_p_float centroids) {
return swigfaissJNI.kmeans_clustering(d, n, k, SWIGTYPE_p_float.getCPtr(x), SWIGTYPE_p_float.getCPtr(centroids));
}
public static void setIndexPQ_stats(IndexPQStats value) {
swigfaissJNI.indexPQ_stats_set(IndexPQStats.getCPtr(value), value);
}
public static IndexPQStats getIndexPQ_stats() {
long cPtr = swigfaissJNI.indexPQ_stats_get();
return (cPtr == 0) ? null : new IndexPQStats(cPtr, false);
}
public static void setIndexIVF_stats(IndexIVFStats value) {
swigfaissJNI.indexIVF_stats_set(IndexIVFStats.getCPtr(value), value);
}
public static IndexIVFStats getIndexIVF_stats() {
long cPtr = swigfaissJNI.indexIVF_stats_get();
return (cPtr == 0) ? null : new IndexIVFStats(cPtr, false);
}
public static short[] getHamdis_tab_ham_bytes() {
return swigfaissJNI.hamdis_tab_ham_bytes_get();
}
public static int generalized_hamming_64(long a) {
return swigfaissJNI.generalized_hamming_64(a);
}
public static void generalized_hammings_knn_hc(SWIGTYPE_p_faiss__HeapArrayT_faiss__CMaxT_int_int64_t_t_t ha, SWIGTYPE_p_unsigned_char a, SWIGTYPE_p_unsigned_char b, long nb, long code_size, int ordered) {
swigfaissJNI.generalized_hammings_knn_hc__SWIG_0(SWIGTYPE_p_faiss__HeapArrayT_faiss__CMaxT_int_int64_t_t_t.getCPtr(ha), SWIGTYPE_p_unsigned_char.getCPtr(a), SWIGTYPE_p_unsigned_char.getCPtr(b), nb, code_size, ordered);
}
public static void generalized_hammings_knn_hc(SWIGTYPE_p_faiss__HeapArrayT_faiss__CMaxT_int_int64_t_t_t ha, SWIGTYPE_p_unsigned_char a, SWIGTYPE_p_unsigned_char b, long nb, long code_size) {
swigfaissJNI.generalized_hammings_knn_hc__SWIG_1(SWIGTYPE_p_faiss__HeapArrayT_faiss__CMaxT_int_int64_t_t_t.getCPtr(ha), SWIGTYPE_p_unsigned_char.getCPtr(a), SWIGTYPE_p_unsigned_char.getCPtr(b), nb, code_size);
}
public static void check_compatible_for_merge(Index index1, Index index2) {
swigfaissJNI.check_compatible_for_merge(Index.getCPtr(index1), index1, Index.getCPtr(index2), index2);
}
public static IndexIVF extract_index_ivf(Index index) {
long cPtr = swigfaissJNI.extract_index_ivf__SWIG_0(Index.getCPtr(index), index);
return (cPtr == 0) ? null : new IndexIVF(cPtr, false);
}
public static IndexIVF try_extract_index_ivf(Index index) {
long cPtr = swigfaissJNI.try_extract_index_ivf__SWIG_0(Index.getCPtr(index), index);
return (cPtr == 0) ? null : new IndexIVF(cPtr, false);
}
public static void merge_into(Index index0, Index index1, boolean shift_ids) {
swigfaissJNI.merge_into(Index.getCPtr(index0), index0, Index.getCPtr(index1), index1, shift_ids);
}
public static void search_centroid(Index index, SWIGTYPE_p_float x, int n, LongVector centroid_ids) {
swigfaissJNI.search_centroid(Index.getCPtr(index), index, SWIGTYPE_p_float.getCPtr(x), n, SWIGTYPE_p_long_long.getCPtr(centroid_ids.data()), centroid_ids);
}
public static void search_and_return_centroids(Index index, long n, SWIGTYPE_p_float xin, int k, SWIGTYPE_p_float distances, LongVector labels, LongVector query_centroid_ids, LongVector result_centroid_ids) {
swigfaissJNI.search_and_return_centroids(Index.getCPtr(index), index, n, SWIGTYPE_p_float.getCPtr(xin), k, SWIGTYPE_p_float.getCPtr(distances), SWIGTYPE_p_long_long.getCPtr(labels.data()), labels, SWIGTYPE_p_long_long.getCPtr(query_centroid_ids.data()), query_centroid_ids, SWIGTYPE_p_long_long.getCPtr(result_centroid_ids.data()), result_centroid_ids);
}
public static ArrayInvertedLists get_invlist_range(Index index, int i0, int i1) {
long cPtr = swigfaissJNI.get_invlist_range(Index.getCPtr(index), index, i0, i1);
return (cPtr == 0) ? null : new ArrayInvertedLists(cPtr, false);
}
public static void set_invlist_range(Index index, int i0, int i1, ArrayInvertedLists src) {
swigfaissJNI.set_invlist_range(Index.getCPtr(index), index, i0, i1, ArrayInvertedLists.getCPtr(src), src);
}
public static void search_with_parameters(Index index, long n, SWIGTYPE_p_float x, long k, SWIGTYPE_p_float distances, LongVector labels, IVFSearchParameters params, SWIGTYPE_p_unsigned_long nb_dis, SWIGTYPE_p_double ms_per_stage) {
swigfaissJNI.search_with_parameters__SWIG_0(Index.getCPtr(index), index, n, SWIGTYPE_p_float.getCPtr(x), k, SWIGTYPE_p_float.getCPtr(distances), SWIGTYPE_p_long_long.getCPtr(labels.data()), labels, IVFSearchParameters.getCPtr(params), params, SWIGTYPE_p_unsigned_long.getCPtr(nb_dis), SWIGTYPE_p_double.getCPtr(ms_per_stage));
}
public static void search_with_parameters(Index index, long n, SWIGTYPE_p_float x, long k, SWIGTYPE_p_float distances, LongVector labels, IVFSearchParameters params, SWIGTYPE_p_unsigned_long nb_dis) {
swigfaissJNI.search_with_parameters__SWIG_1(Index.getCPtr(index), index, n, SWIGTYPE_p_float.getCPtr(x), k, SWIGTYPE_p_float.getCPtr(distances), SWIGTYPE_p_long_long.getCPtr(labels.data()), labels, IVFSearchParameters.getCPtr(params), params, SWIGTYPE_p_unsigned_long.getCPtr(nb_dis));
}
public static void search_with_parameters(Index index, long n, SWIGTYPE_p_float x, long k, SWIGTYPE_p_float distances, LongVector labels, IVFSearchParameters params) {
swigfaissJNI.search_with_parameters__SWIG_2(Index.getCPtr(index), index, n, SWIGTYPE_p_float.getCPtr(x), k, SWIGTYPE_p_float.getCPtr(distances), SWIGTYPE_p_long_long.getCPtr(labels.data()), labels, IVFSearchParameters.getCPtr(params), params);
}
public static void range_search_with_parameters(Index index, long n, SWIGTYPE_p_float x, float radius, RangeSearchResult result, IVFSearchParameters params, SWIGTYPE_p_unsigned_long nb_dis, SWIGTYPE_p_double ms_per_stage) {
swigfaissJNI.range_search_with_parameters__SWIG_0(Index.getCPtr(index), index, n, SWIGTYPE_p_float.getCPtr(x), radius, RangeSearchResult.getCPtr(result), result, IVFSearchParameters.getCPtr(params), params, SWIGTYPE_p_unsigned_long.getCPtr(nb_dis), SWIGTYPE_p_double.getCPtr(ms_per_stage));
}
public static void range_search_with_parameters(Index index, long n, SWIGTYPE_p_float x, float radius, RangeSearchResult result, IVFSearchParameters params, SWIGTYPE_p_unsigned_long nb_dis) {
swigfaissJNI.range_search_with_parameters__SWIG_1(Index.getCPtr(index), index, n, SWIGTYPE_p_float.getCPtr(x), radius, RangeSearchResult.getCPtr(result), result, IVFSearchParameters.getCPtr(params), params, SWIGTYPE_p_unsigned_long.getCPtr(nb_dis));
}
public static void range_search_with_parameters(Index index, long n, SWIGTYPE_p_float x, float radius, RangeSearchResult result, IVFSearchParameters params) {
swigfaissJNI.range_search_with_parameters__SWIG_2(Index.getCPtr(index), index, n, SWIGTYPE_p_float.getCPtr(x), radius, RangeSearchResult.getCPtr(result), result, IVFSearchParameters.getCPtr(params), params);
}
public static void setHnsw_stats(HNSWStats value) {
swigfaissJNI.hnsw_stats_set(HNSWStats.getCPtr(value), value);
}
public static HNSWStats getHnsw_stats() {
long cPtr = swigfaissJNI.hnsw_stats_get();
return (cPtr == 0) ? null : new HNSWStats(cPtr, false);
}
public static void setPrecomputed_table_max_bytes(long value) {
swigfaissJNI.precomputed_table_max_bytes_set(value);
}
public static long getPrecomputed_table_max_bytes() {
return swigfaissJNI.precomputed_table_max_bytes_get();
}
public static void initialize_IVFPQ_precomputed_table(SWIGTYPE_p_int use_precomputed_table, Index quantizer, ProductQuantizer pq, SWIGTYPE_p_AlignedTableT_float_32_t precomputed_table, boolean verbose) {
swigfaissJNI.initialize_IVFPQ_precomputed_table(SWIGTYPE_p_int.getCPtr(use_precomputed_table), Index.getCPtr(quantizer), quantizer, ProductQuantizer.getCPtr(pq), pq, SWIGTYPE_p_AlignedTableT_float_32_t.getCPtr(precomputed_table), verbose);
}
public static void setIndexIVFPQ_stats(IndexIVFPQStats value) {
swigfaissJNI.indexIVFPQ_stats_set(IndexIVFPQStats.getCPtr(value), value);
}
public static IndexIVFPQStats getIndexIVFPQ_stats() {
long cPtr = swigfaissJNI.indexIVFPQ_stats_get();
return (cPtr == 0) ? null : new IndexIVFPQStats(cPtr, false);
}
public static Index downcast_index(Index index) {
long cPtr = swigfaissJNI.downcast_index(Index.getCPtr(index), index);
return (cPtr == 0) ? null : new Index(cPtr, false);
}
public static VectorTransform downcast_VectorTransform(VectorTransform vt) {
long cPtr = swigfaissJNI.downcast_VectorTransform(VectorTransform.getCPtr(vt), vt);
return (cPtr == 0) ? null : new VectorTransform(cPtr, false);
}
public static IndexBinary downcast_IndexBinary(IndexBinary index) {
long cPtr = swigfaissJNI.downcast_IndexBinary(IndexBinary.getCPtr(index), index);
return (cPtr == 0) ? null : new IndexBinary(cPtr, false);
}
public static Index upcast_IndexShards(IndexShards index) {
long cPtr = swigfaissJNI.upcast_IndexShards(IndexShards.getCPtr(index), index);
return (cPtr == 0) ? null : new Index(cPtr, false);
}
public static void write_index(Index idx, String fname) {
swigfaissJNI.write_index__SWIG_0(Index.getCPtr(idx), idx, fname);
}
public static void write_index(Index idx, SWIGTYPE_p_FILE f) {
swigfaissJNI.write_index__SWIG_1(Index.getCPtr(idx), idx, SWIGTYPE_p_FILE.getCPtr(f));
}
public static void write_index(Index idx, SWIGTYPE_p_faiss__IOWriter writer) {
swigfaissJNI.write_index__SWIG_2(Index.getCPtr(idx), idx, SWIGTYPE_p_faiss__IOWriter.getCPtr(writer));
}
public static void write_index_binary(IndexBinary idx, String fname) {
swigfaissJNI.write_index_binary__SWIG_0(IndexBinary.getCPtr(idx), idx, fname);
}
public static void write_index_binary(IndexBinary idx, SWIGTYPE_p_FILE f) {
swigfaissJNI.write_index_binary__SWIG_1(IndexBinary.getCPtr(idx), idx, SWIGTYPE_p_FILE.getCPtr(f));
}
public static void write_index_binary(IndexBinary idx, SWIGTYPE_p_faiss__IOWriter writer) {
swigfaissJNI.write_index_binary__SWIG_2(IndexBinary.getCPtr(idx), idx, SWIGTYPE_p_faiss__IOWriter.getCPtr(writer));
}
public static int getIO_FLAG_READ_ONLY() {
return swigfaissJNI.IO_FLAG_READ_ONLY_get();
}
public static int getIO_FLAG_ONDISK_SAME_DIR() {
return swigfaissJNI.IO_FLAG_ONDISK_SAME_DIR_get();
}
public static int getIO_FLAG_SKIP_IVF_DATA() {
return swigfaissJNI.IO_FLAG_SKIP_IVF_DATA_get();
}
public static int getIO_FLAG_MMAP() {
return swigfaissJNI.IO_FLAG_MMAP_get();
}
public static Index read_index(String fname, int io_flags) {
long cPtr = swigfaissJNI.read_index__SWIG_0(fname, io_flags);
return (cPtr == 0) ? null : new Index(cPtr, true);
}
public static Index read_index(String fname) {
long cPtr = swigfaissJNI.read_index__SWIG_1(fname);
return (cPtr == 0) ? null : new Index(cPtr, true);
}
public static Index read_index(SWIGTYPE_p_FILE f, int io_flags) {
long cPtr = swigfaissJNI.read_index__SWIG_2(SWIGTYPE_p_FILE.getCPtr(f), io_flags);
return (cPtr == 0) ? null : new Index(cPtr, true);
}
public static Index read_index(SWIGTYPE_p_FILE f) {
long cPtr = swigfaissJNI.read_index__SWIG_3(SWIGTYPE_p_FILE.getCPtr(f));
return (cPtr == 0) ? null : new Index(cPtr, true);
}
public static Index read_index(SWIGTYPE_p_faiss__IOReader reader, int io_flags) {
long cPtr = swigfaissJNI.read_index__SWIG_4(SWIGTYPE_p_faiss__IOReader.getCPtr(reader), io_flags);
return (cPtr == 0) ? null : new Index(cPtr, true);
}
public static Index read_index(SWIGTYPE_p_faiss__IOReader reader) {
long cPtr = swigfaissJNI.read_index__SWIG_5(SWIGTYPE_p_faiss__IOReader.getCPtr(reader));
return (cPtr == 0) ? null : new Index(cPtr, true);
}
public static IndexBinary read_index_binary(String fname, int io_flags) {
long cPtr = swigfaissJNI.read_index_binary__SWIG_0(fname, io_flags);
return (cPtr == 0) ? null : new IndexBinary(cPtr, true);
}
public static IndexBinary read_index_binary(String fname) {
long cPtr = swigfaissJNI.read_index_binary__SWIG_1(fname);
return (cPtr == 0) ? null : new IndexBinary(cPtr, true);
}
public static IndexBinary read_index_binary(SWIGTYPE_p_FILE f, int io_flags) {
long cPtr = swigfaissJNI.read_index_binary__SWIG_2(SWIGTYPE_p_FILE.getCPtr(f), io_flags);
return (cPtr == 0) ? null : new IndexBinary(cPtr, true);
}
public static IndexBinary read_index_binary(SWIGTYPE_p_FILE f) {
long cPtr = swigfaissJNI.read_index_binary__SWIG_3(SWIGTYPE_p_FILE.getCPtr(f));
return (cPtr == 0) ? null : new IndexBinary(cPtr, true);
}
public static IndexBinary read_index_binary(SWIGTYPE_p_faiss__IOReader reader, int io_flags) {
long cPtr = swigfaissJNI.read_index_binary__SWIG_4(SWIGTYPE_p_faiss__IOReader.getCPtr(reader), io_flags);
return (cPtr == 0) ? null : new IndexBinary(cPtr, true);
}
public static IndexBinary read_index_binary(SWIGTYPE_p_faiss__IOReader reader) {
long cPtr = swigfaissJNI.read_index_binary__SWIG_5(SWIGTYPE_p_faiss__IOReader.getCPtr(reader));
return (cPtr == 0) ? null : new IndexBinary(cPtr, true);
}
public static void write_VectorTransform(VectorTransform vt, String fname) {
swigfaissJNI.write_VectorTransform(VectorTransform.getCPtr(vt), vt, fname);
}
public static VectorTransform read_VectorTransform(String fname) {
long cPtr = swigfaissJNI.read_VectorTransform(fname);
return (cPtr == 0) ? null : new VectorTransform(cPtr, true);
}
public static ProductQuantizer read_ProductQuantizer(String fname) {
long cPtr = swigfaissJNI.read_ProductQuantizer__SWIG_0(fname);
return (cPtr == 0) ? null : new ProductQuantizer(cPtr, true);
}
public static ProductQuantizer read_ProductQuantizer(SWIGTYPE_p_faiss__IOReader reader) {
long cPtr = swigfaissJNI.read_ProductQuantizer__SWIG_1(SWIGTYPE_p_faiss__IOReader.getCPtr(reader));
return (cPtr == 0) ? null : new ProductQuantizer(cPtr, true);
}
public static void write_ProductQuantizer(ProductQuantizer pq, String fname) {
swigfaissJNI.write_ProductQuantizer__SWIG_0(ProductQuantizer.getCPtr(pq), pq, fname);
}
public static void write_ProductQuantizer(ProductQuantizer pq, SWIGTYPE_p_faiss__IOWriter f) {
swigfaissJNI.write_ProductQuantizer__SWIG_1(ProductQuantizer.getCPtr(pq), pq, SWIGTYPE_p_faiss__IOWriter.getCPtr(f));
}
public static void write_InvertedLists(InvertedLists ils, SWIGTYPE_p_faiss__IOWriter f) {
swigfaissJNI.write_InvertedLists(InvertedLists.getCPtr(ils), ils, SWIGTYPE_p_faiss__IOWriter.getCPtr(f));
}
public static InvertedLists read_InvertedLists(SWIGTYPE_p_faiss__IOReader reader, int io_flags) {
long cPtr = swigfaissJNI.read_InvertedLists__SWIG_0(SWIGTYPE_p_faiss__IOReader.getCPtr(reader), io_flags);
return (cPtr == 0) ? null : new InvertedLists(cPtr, false);
}
public static InvertedLists read_InvertedLists(SWIGTYPE_p_faiss__IOReader reader) {
long cPtr = swigfaissJNI.read_InvertedLists__SWIG_1(SWIGTYPE_p_faiss__IOReader.getCPtr(reader));
return (cPtr == 0) ? null : new InvertedLists(cPtr, false);
}
public static Index index_factory(int d, String description, MetricType metric) {
long cPtr = swigfaissJNI.index_factory__SWIG_0(d, description, metric.swigValue());
return (cPtr == 0) ? null : new Index(cPtr, true);
}
public static Index index_factory(int d, String description) {
long cPtr = swigfaissJNI.index_factory__SWIG_1(d, description);
return (cPtr == 0) ? null : new Index(cPtr, true);
}
public static void setIndex_factory_verbose(int value) {
swigfaissJNI.index_factory_verbose_set(value);
}
public static int getIndex_factory_verbose() {
return swigfaissJNI.index_factory_verbose_get();
}
public static IndexBinary index_binary_factory(int d, String description) {
long cPtr = swigfaissJNI.index_binary_factory(d, description);
return (cPtr == 0) ? null : new IndexBinary(cPtr, true);
}
public static void simd_histogram_8(SWIGTYPE_p_uint16_t data, int n, SWIGTYPE_p_uint16_t min, int shift, SWIGTYPE_p_int hist) {
swigfaissJNI.simd_histogram_8(SWIGTYPE_p_uint16_t.getCPtr(data), n, SWIGTYPE_p_uint16_t.getCPtr(min), shift, SWIGTYPE_p_int.getCPtr(hist));
}
public static void simd_histogram_16(SWIGTYPE_p_uint16_t data, int n, SWIGTYPE_p_uint16_t min, int shift, SWIGTYPE_p_int hist) {
swigfaissJNI.simd_histogram_16(SWIGTYPE_p_uint16_t.getCPtr(data), n, SWIGTYPE_p_uint16_t.getCPtr(min), shift, SWIGTYPE_p_int.getCPtr(hist));
}
public static void setPartition_stats(PartitionStats value) {
swigfaissJNI.partition_stats_set(PartitionStats.getCPtr(value), value);
}
public static PartitionStats getPartition_stats() {
long cPtr = swigfaissJNI.partition_stats_get();
return (cPtr == 0) ? null : new PartitionStats(cPtr, false);
}
public static float CMin_float_partition_fuzzy(SWIGTYPE_p_float vals, LongVector ids, long n, long q_min, long q_max, SWIGTYPE_p_unsigned_long q_out) {
return swigfaissJNI.CMin_float_partition_fuzzy(SWIGTYPE_p_float.getCPtr(vals), SWIGTYPE_p_long_long.getCPtr(ids.data()), ids, n, q_min, q_max, SWIGTYPE_p_unsigned_long.getCPtr(q_out));
}
public static float CMax_float_partition_fuzzy(SWIGTYPE_p_float vals, LongVector ids, long n, long q_min, long q_max, SWIGTYPE_p_unsigned_long q_out) {
return swigfaissJNI.CMax_float_partition_fuzzy(SWIGTYPE_p_float.getCPtr(vals), SWIGTYPE_p_long_long.getCPtr(ids.data()), ids, n, q_min, q_max, SWIGTYPE_p_unsigned_long.getCPtr(q_out));
}
public static SWIGTYPE_p_uint16_t CMax_uint16_partition_fuzzy(SWIGTYPE_p_uint16_t vals, LongVector ids, long n, long q_min, long q_max, SWIGTYPE_p_unsigned_long q_out) {
return new SWIGTYPE_p_uint16_t(swigfaissJNI.CMax_uint16_partition_fuzzy__SWIG_0(SWIGTYPE_p_uint16_t.getCPtr(vals), SWIGTYPE_p_long_long.getCPtr(ids.data()), ids, n, q_min, q_max, SWIGTYPE_p_unsigned_long.getCPtr(q_out)), true);
}
public static SWIGTYPE_p_uint16_t CMin_uint16_partition_fuzzy(SWIGTYPE_p_uint16_t vals, LongVector ids, long n, long q_min, long q_max, SWIGTYPE_p_unsigned_long q_out) {
return new SWIGTYPE_p_uint16_t(swigfaissJNI.CMin_uint16_partition_fuzzy__SWIG_0(SWIGTYPE_p_uint16_t.getCPtr(vals), SWIGTYPE_p_long_long.getCPtr(ids.data()), ids, n, q_min, q_max, SWIGTYPE_p_unsigned_long.getCPtr(q_out)), true);
}
public static SWIGTYPE_p_uint16_t CMax_uint16_partition_fuzzy(SWIGTYPE_p_uint16_t vals, SWIGTYPE_p_int ids, long n, long q_min, long q_max, SWIGTYPE_p_unsigned_long q_out) {
return new SWIGTYPE_p_uint16_t(swigfaissJNI.CMax_uint16_partition_fuzzy__SWIG_1(SWIGTYPE_p_uint16_t.getCPtr(vals), SWIGTYPE_p_int.getCPtr(ids), n, q_min, q_max, SWIGTYPE_p_unsigned_long.getCPtr(q_out)), true);
}
public static SWIGTYPE_p_uint16_t CMin_uint16_partition_fuzzy(SWIGTYPE_p_uint16_t vals, SWIGTYPE_p_int ids, long n, long q_min, long q_max, SWIGTYPE_p_unsigned_long q_out) {
return new SWIGTYPE_p_uint16_t(swigfaissJNI.CMin_uint16_partition_fuzzy__SWIG_1(SWIGTYPE_p_uint16_t.getCPtr(vals), SWIGTYPE_p_int.getCPtr(ids), n, q_min, q_max, SWIGTYPE_p_unsigned_long.getCPtr(q_out)), true);
}
public static void omp_set_num_threads(int num_threads) {
swigfaissJNI.omp_set_num_threads(num_threads);
}
public static int omp_get_max_threads() {
return swigfaissJNI.omp_get_max_threads();
}
public static SWIGTYPE_p_void memcpy(SWIGTYPE_p_void dest, SWIGTYPE_p_void src, long n) {
long cPtr = swigfaissJNI.memcpy(SWIGTYPE_p_void.getCPtr(dest), SWIGTYPE_p_void.getCPtr(src), n);
return (cPtr == 0) ? null : new SWIGTYPE_p_void(cPtr, false);
}
public static SWIGTYPE_p_float cast_integer_to_float_ptr(int x) {
long cPtr = swigfaissJNI.cast_integer_to_float_ptr(x);
return (cPtr == 0) ? null : new SWIGTYPE_p_float(cPtr, false);
}
public static SWIGTYPE_p_long cast_integer_to_long_ptr(int x) {
long cPtr = swigfaissJNI.cast_integer_to_long_ptr(x);
return (cPtr == 0) ? null : new SWIGTYPE_p_long(cPtr, false);
}
public static SWIGTYPE_p_int cast_integer_to_int_ptr(int x) {
long cPtr = swigfaissJNI.cast_integer_to_int_ptr(x);
return (cPtr == 0) ? null : new SWIGTYPE_p_int(cPtr, false);
}
public static void ignore_SIGTTIN() {
swigfaissJNI.ignore_SIGTTIN();
}
}
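
As a usage reference, here is a minimal round-trip through the factory and I/O helpers above (hypothetical caller code; MetricType is assumed to be the generated enum from this same package, and the path is illustrative):

// Build an empty IVF index from a factory string, persist it, reload it.
Index index = swigfaiss.index_factory(128, "IVF256,Flat", MetricType.METRIC_L2);
// ... train and add vectors through the Index wrapper before persisting ...
swigfaiss.write_index(index, "/tmp/example.faiss");
Index reloaded = swigfaiss.read_index("/tmp/example.faiss");
// The module also exposes runtime knobs, e.g. OpenMP threading:
swigfaiss.omp_set_num_threads(swigfaiss.omp_get_max_threads());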

View File

@@ -1,15 +0,0 @@
/* ----------------------------------------------------------------------------
* This file was automatically generated by SWIG (http://www.swig.org).
* Version 4.0.2
*
* Do not make changes to this file unless you know what you are doing--modify
* the SWIG interface file instead.
* ----------------------------------------------------------------------------- */
package com.twitter.ann.faiss;
public interface swigfaissConstants {
public final static int FAISS_VERSION_MAJOR = swigfaissJNI.FAISS_VERSION_MAJOR_get();
public final static int FAISS_VERSION_MINOR = swigfaissJNI.FAISS_VERSION_MINOR_get();
public final static int FAISS_VERSION_PATCH = swigfaissJNI.FAISS_VERSION_PATCH_get();
}
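
These constants come straight from the native build; for example, a caller could log the linked FAISS version (sketch):

String faissVersion = swigfaissConstants.FAISS_VERSION_MAJOR + "."
    + swigfaissConstants.FAISS_VERSION_MINOR + "."
    + swigfaissConstants.FAISS_VERSION_PATCH;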

File diff suppressed because it is too large

View File

@@ -1,18 +0,0 @@
java_library(
sources = ["*.java"],
compiler_option_sets = ["fatal_warnings"],
platform = "java8",
tags = ["bazel-compatible"],
dependencies = [
"3rdparty/jvm/com/google/guava",
"3rdparty/jvm/com/google/inject:guice",
"3rdparty/jvm/com/twitter/bijection:core",
"3rdparty/jvm/commons-lang",
"3rdparty/jvm/org/apache/thrift",
"ann/src/main/scala/com/twitter/ann/common",
"ann/src/main/thrift/com/twitter/ann/common:ann-common-java",
"mediaservices/commons/src/main/scala:futuretracker",
"scrooge/scrooge-core",
"src/java/com/twitter/search/common/file",
],
)

Binary file not shown.

View File

@@ -1,8 +0,0 @@
package com.twitter.ann.hnsw;
public interface DistanceFunction<T, Q> {
/**
* Distance between two items.
*/
float distance(T t, Q q);
}
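
Being a single-method interface, it admits lambda implementations. A minimal sketch, assuming float[] vectors and squared-L2 distance (hypothetical, not part of the deleted file):

// Squared Euclidean distance between two equal-length float vectors.
DistanceFunction<float[], float[]> squaredL2 = (a, b) -> {
  float sum = 0f;
  for (int i = 0; i < a.length; i++) {
    float d = a[i] - b[i];
    sum += d * d;
  }
  return sum;
};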

View File

@@ -1,23 +0,0 @@
package com.twitter.ann.hnsw;
/**
* An item associated with a float distance
* @param <T> The type of the item.
*/
public class DistancedItem<T> {
private final T item;
private final float distance;
public DistancedItem(T item, float distance) {
this.item = item;
this.distance = distance;
}
public T getItem() {
return item;
}
public float getDistance() {
return distance;
}
}

View File

@ -1,196 +0,0 @@
package com.twitter.ann.hnsw;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import java.util.PriorityQueue;
/**
* Container for items with their distance.
*
* @param <U> Type of origin/reference element.
* @param <T> Type of element that the queue will hold
*/
public class DistancedItemQueue<U, T> implements Iterable<DistancedItem<T>> {
private final U origin;
private final DistanceFunction<U, T> distFn;
private final PriorityQueue<DistancedItem<T>> queue;
private final boolean minQueue;
/**
* Creates a container for items with their distances.
*
* @param origin Origin (reference) point
* @param initial Initial list of elements to add in the structure
* @param minQueue True for min queue, False for max queue
* @param distFn Distance function
*/
public DistancedItemQueue(
U origin,
List<T> initial,
boolean minQueue,
DistanceFunction<U, T> distFn
) {
this.origin = origin;
this.distFn = distFn;
this.minQueue = minQueue;
final Comparator<DistancedItem<T>> cmp;
if (minQueue) {
cmp = (o1, o2) -> Float.compare(o1.getDistance(), o2.getDistance());
} else {
cmp = (o1, o2) -> Float.compare(o2.getDistance(), o1.getDistance());
}
this.queue = new PriorityQueue<>(cmp);
enqueueAll(initial);
}
private DistancedItemQueue(
U origin,
DistanceFunction<U, T> distFn,
PriorityQueue<DistancedItem<T>> queue,
boolean minQueue
) {
this.origin = origin;
this.distFn = distFn;
this.queue = queue;
this.minQueue = minQueue;
}
/**
* Enqueues all the items into the queue.
*/
public void enqueueAll(List<T> list) {
for (T t : list) {
enqueue(t);
}
}
/**
* Return whether the queue is non-empty
*
* @return true if the queue is not empty, false otherwise
*/
public boolean nonEmpty() {
return !queue.isEmpty();
}
/**
* Return root of the queue
*
* @return root of the queue, i.e. the min/max element depending on whether this is a min or max queue
*/
public DistancedItem<T> peek() {
return queue.peek();
}
/**
* Dequeue root of the queue.
*
* @return the removed root of the queue, i.e. the min/max element depending on whether this is a min or max queue
*/
public DistancedItem<T> dequeue() {
return queue.poll();
}
/**
* Dequeue all the elements from the queue with ordering maintained
*
* @return all the elements, removed in the order of the queue, i.e. min/max order
*/
public List<DistancedItem<T>> dequeueAll() {
final List<DistancedItem<T>> list = new ArrayList<>(queue.size());
while (!queue.isEmpty()) {
list.add(queue.poll());
}
return list;
}
/**
* Convert queue to list
*
* @return list of the queue's elements with their distances, in no specific order
*/
public List<DistancedItem<T>> toList() {
return new ArrayList<>(queue);
}
/**
* Convert queue to a list of items
*
* @return list of the queue's items (without distances), in no specific order
*/
List<T> toListWithItem() {
List<T> list = new ArrayList<>(queue.size());
Iterator<DistancedItem<T>> itr = iterator();
while (itr.hasNext()) {
list.add(itr.next().getItem());
}
return list;
}
/**
* Enqueue an item into the queue
*/
public void enqueue(T item) {
queue.add(new DistancedItem<>(item, distFn.distance(origin, item)));
}
/**
* Enqueue an item into the queue with its distance.
*/
public void enqueue(T item, float distance) {
queue.add(new DistancedItem<>(item, distance));
}
/**
* Size
*
* @return size of the queue
*/
public int size() {
return queue.size();
}
/**
* Is Min queue
*
* @return true if this is a min queue, false otherwise
*/
public boolean isMinQueue() {
return minQueue;
}
/**
* Returns origin (base element) of the queue
*
* @return origin of the queue
*/
public U getOrigin() {
return origin;
}
/**
* Return a new queue with ordering reversed.
*/
public DistancedItemQueue<U, T> reverse() {
final PriorityQueue<DistancedItem<T>> rqueue =
new PriorityQueue<>(queue.comparator().reversed());
if (queue.isEmpty()) {
return new DistancedItemQueue<>(origin, distFn, rqueue, !isMinQueue());
}
final Iterator<DistancedItem<T>> itr = iterator();
while (itr.hasNext()) {
rqueue.add(itr.next());
}
return new DistancedItemQueue<>(origin, distFn, rqueue, !isMinQueue());
}
@Override
public Iterator<DistancedItem<T>> iterator() {
return queue.iterator();
}
}
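
A minimal usage sketch for the queue, reusing the illustrative SquaredL2Distance above: build a max queue of candidates around an origin, then reverse it so the closest candidate pops first.

package com.twitter.ann.hnsw;

import java.util.Arrays;

public final class DistancedItemQueueExample {
  public static void main(String[] args) {
    final DistanceFunction<float[], float[]> dist = new SquaredL2Distance();
    final float[] origin = {0f, 0f};
    // minQueue = false -> max queue: the farthest candidate sits at the root.
    final DistancedItemQueue<float[], float[]> maxQueue = new DistancedItemQueue<>(
        origin,
        Arrays.asList(new float[] {1f, 0f}, new float[] {3f, 4f}),
        false,
        dist);
    // reverse() returns a min queue, so the closest candidate is dequeued first.
    final DistancedItem<float[]> closest = maxQueue.reverse().dequeue();
    System.out.println(closest.getDistance()); // 1.0, i.e. the vector (1, 0)
  }
}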

Binary file not shown.

View File

@ -1,711 +0,0 @@
package com.twitter.ann.hnsw;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Random;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicReference;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReadWriteLock;
import java.util.concurrent.locks.ReentrantLock;
import java.util.concurrent.locks.ReentrantReadWriteLock;
import java.util.function.Function;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import org.apache.thrift.TException;
import com.twitter.ann.common.IndexOutputFile;
import com.twitter.ann.common.thriftjava.HnswInternalIndexMetadata;
import com.twitter.bijection.Injection;
import com.twitter.logging.Logger;
import com.twitter.mediaservices.commons.codec.ArrayByteBufferCodec;
import com.twitter.search.common.file.AbstractFile;
/**
* Typed multithreaded HNSW implementation supporting creation/querying of an approximate nearest neighbour index.
* Paper: https://arxiv.org/pdf/1603.09320.pdf
* Multithreading impl based on the NMSLIB version: https://github.com/nmslib/hnsw/blob/master/hnswlib/hnswalg.h
*
* @param <T> The type of items inserted / searched in the HNSW index.
* @param <Q> The type of KNN query.
*/
public class HnswIndex<T, Q> {
private static final Logger LOG = Logger.get(HnswIndex.class);
private static final String METADATA_FILE_NAME = "hnsw_internal_metadata";
private static final String GRAPH_FILE_NAME = "hnsw_internal_graph";
private static final int MAP_SIZE_FACTOR = 5;
private final DistanceFunction<T, T> distFnIndex;
private final DistanceFunction<Q, T> distFnQuery;
private final int efConstruction;
private final int maxM;
private final int maxM0;
private final double levelMultiplier;
private final AtomicReference<HnswMeta<T>> graphMeta = new AtomicReference<>();
private final Map<HnswNode<T>, ImmutableList<T>> graph;
// To take a lock at the vertex level
private final ConcurrentHashMap<T, ReadWriteLock> locks;
// To take a lock on the whole graph, only if a vertex addition is on a layer above the current maxLevel
private final ReentrantLock globalLock;
private final Function<T, ReadWriteLock> lockProvider;
private final RandomProvider randomProvider;
// Probability of reevaluating connections of an element in the neighborhood during an update
// Can be used as a knob to adjust update_speed/search_speed tradeoff.
private final float updateNeighborProbability;
/**
* Creates instance of hnsw index.
*
* @param distFnIndex Any distance metric/non metric that specifies similarity between two items for indexing.
* @param distFnQuery Any distance metric/non metric that specifies similarity between item for which nearest neighbours queried for and already indexed item.
* @param efConstruction Provides a speed vs. index quality tradeoff: the higher the value, the better the quality and the longer it takes to create the index.
* Valid range of efConstruction can be anywhere between 1 and tens of thousands. Typically, it should be set so that a search of M
* neighbors with ef=efConstruction ends in recall > 0.95.
* @param maxM Maximum connections per layer except the 0th level.
* Optimal values are between 5 and 48.
* Smaller M generally produces better results for lower recalls and/or lower-dimensional data,
* while bigger M is better for high recall and/or high-dimensional data, at the expense of more memory/disk usage
* @param expectedElements Approximate number of elements to be indexed
*/
protected HnswIndex(
DistanceFunction<T, T> distFnIndex,
DistanceFunction<Q, T> distFnQuery,
int efConstruction,
int maxM,
int expectedElements,
RandomProvider randomProvider
) {
this(distFnIndex,
distFnQuery,
efConstruction,
maxM,
expectedElements,
new HnswMeta<>(-1, Optional.empty()),
new ConcurrentHashMap<>(MAP_SIZE_FACTOR * expectedElements),
randomProvider
);
}
private HnswIndex(
DistanceFunction<T, T> distFnIndex,
DistanceFunction<Q, T> distFnQuery,
int efConstruction,
int maxM,
int expectedElements,
HnswMeta<T> graphMeta,
Map<HnswNode<T>, ImmutableList<T>> graph,
RandomProvider randomProvider
) {
this.distFnIndex = distFnIndex;
this.distFnQuery = distFnQuery;
this.efConstruction = efConstruction;
this.maxM = maxM;
this.maxM0 = 2 * maxM;
this.levelMultiplier = 1.0 / Math.log(1.0 * maxM);
this.graphMeta.set(graphMeta);
this.graph = graph;
this.locks = new ConcurrentHashMap<>(MAP_SIZE_FACTOR * expectedElements);
this.globalLock = new ReentrantLock();
this.lockProvider = key -> new ReentrantReadWriteLock();
this.randomProvider = randomProvider;
this.updateNeighborProbability = 1.0f;
}
/**
* wireConnectionForAllLayers finds connections for a new element and creates bi-directional links.
* The method assumes reentrant locking is used for link-list reads.
*
* @param entryPoint the global entry point
* @param item the item for which the connections are found
* @param itemLevel the level of the added item (maximum layer in which we wire the connections)
* @param maxLayer the level of the entry point
*/
private void wireConnectionForAllLayers(final T entryPoint, final T item, final int itemLevel,
final int maxLayer, final boolean isUpdate) {
T curObj = entryPoint;
if (itemLevel < maxLayer) {
curObj = bestEntryPointUntilLayer(curObj, item, maxLayer, itemLevel, distFnIndex);
}
for (int level = Math.min(itemLevel, maxLayer); level >= 0; level--) {
final DistancedItemQueue<T, T> candidates =
searchLayerForCandidates(item, curObj, efConstruction, level, distFnIndex, isUpdate);
curObj = mutuallyConnectNewElement(item, candidates, level, isUpdate);
}
}
/**
* Insert the item into HNSW index.
*/
public void insert(final T item) throws IllegalDuplicateInsertException {
final Lock itemLock = locks.computeIfAbsent(item, lockProvider).writeLock();
itemLock.lock();
try {
final HnswMeta<T> metadata = graphMeta.get();
// If the graph already has the item, we should not re-insert it
// Need to check the entry point in case we reinsert the first item, when there is no graph
// but only an entry point
if (graph.containsKey(HnswNode.from(0, item))
|| (metadata.getEntryPoint().isPresent()
&& Objects.equals(metadata.getEntryPoint().get(), item))) {
throw new IllegalDuplicateInsertException(
"Duplicate insertion is not supported: " + item);
}
final int curLevel = getRandomLevel();
Optional<T> entryPoint = metadata.getEntryPoint();
// The global lock prevents two threads from making changes to the entry point. This lock
// should get taken very infrequently. Something like log-base-levelMultiplier(num items)
// For a full explanation of locking see this document: http://go/hnsw-locking
int maxLevelCopy = metadata.getMaxLevel();
if (curLevel > maxLevelCopy) {
globalLock.lock();
// Re-initialize the entryPoint and maxLevel in case they were changed by another thread
// No need to check the condition again, since it is already checked at the end
// before updating the entry point struct
// No need to unlock here as an optimization if the condition fails, since threads
// will not enter this section often.
final HnswMeta<T> temp = graphMeta.get();
entryPoint = temp.getEntryPoint();
maxLevelCopy = temp.getMaxLevel();
}
if (entryPoint.isPresent()) {
wireConnectionForAllLayers(entryPoint.get(), item, curLevel, maxLevelCopy, false);
}
if (curLevel > maxLevelCopy) {
Preconditions.checkState(globalLock.isHeldByCurrentThread(),
"Global lock not held before updating entry point");
graphMeta.set(new HnswMeta<>(curLevel, Optional.of(item)));
}
} finally {
if (globalLock.isHeldByCurrentThread()) {
globalLock.unlock();
}
itemLock.unlock();
}
}
/**
* Set connections of an element with synchronization.
* The only other place that should hold the lock for writing is during
* element insertion
*/
private void setConnectionList(final T item, int layer, List<T> connections) {
final Lock candidateLock = locks.computeIfAbsent(item, lockProvider).writeLock();
candidateLock.lock();
try {
graph.put(
HnswNode.from(layer, item),
ImmutableList.copyOf(connections)
);
} finally {
candidateLock.unlock();
}
}
/**
* Reinsert the item into HNSW index.
* This method updates the links of an element assuming
* the element's distance function is changed externally (e.g. by updating the features)
*/
public void reInsert(final T item) {
final HnswMeta<T> metadata = graphMeta.get();
Optional<T> entryPoint = metadata.getEntryPoint();
Preconditions.checkState(entryPoint.isPresent(),
"Update cannot be performed if entry point is not present");
// This is a check for the single element case
if (entryPoint.get().equals(item) && graph.isEmpty()) {
return;
}
Preconditions.checkState(graph.containsKey(HnswNode.from(0, item)),
"Graph does not contain the item to be updated at level 0");
int curLevel = 0;
int maxLevelCopy = metadata.getMaxLevel();
for (int layer = maxLevelCopy; layer >= 0; layer--) {
if (graph.containsKey(HnswNode.from(layer, item))) {
curLevel = layer;
break;
}
}
// Updating the links of the elements from the 1-hop radius of the updated element
for (int layer = 0; layer <= curLevel; layer++) {
// Filling the element sets for candidates and updated elements
final HashSet<T> setCand = new HashSet<T>();
final HashSet<T> setNeigh = new HashSet<T>();
final List<T> listOneHop = getConnectionListForRead(item, layer);
if (listOneHop.isEmpty()) {
LOG.debug("No links for the updated element. Empty dataset?");
continue;
}
setCand.add(item);
for (T elOneHop : listOneHop) {
setCand.add(elOneHop);
if (randomProvider.get().nextFloat() > updateNeighborProbability) {
continue;
}
setNeigh.add(elOneHop);
final List<T> listTwoHop = getConnectionListForRead(elOneHop, layer);
if (listTwoHop.isEmpty()) {
LOG.debug("No links for the updated element. Empty dataset?");
}
for (T twoHopEl : listTwoHop) {
setCand.add(twoHopEl);
}
}
// No need to update the item itself, so remove it
setNeigh.remove(item);
// Updating the link lists of elements from setNeigh:
for (T neigh : setNeigh) {
final HashSet<T> setCopy = new HashSet<T>(setCand);
setCopy.remove(neigh);
int keepElementsNum = Math.min(efConstruction, setCopy.size());
final DistancedItemQueue<T, T> candidates = new DistancedItemQueue<>(
neigh,
ImmutableList.of(),
false,
distFnIndex
);
for (T cand : setCopy) {
final float distance = distFnIndex.distance(neigh, cand);
if (candidates.size() < keepElementsNum) {
candidates.enqueue(cand, distance);
} else {
if (distance < candidates.peek().getDistance()) {
candidates.dequeue();
candidates.enqueue(cand, distance);
}
}
}
final ImmutableList<T> neighbours = selectNearestNeighboursByHeuristic(
candidates,
layer == 0 ? maxM0 : maxM
);
final List<T> temp = getConnectionListForRead(neigh, layer);
if (temp.isEmpty()) {
LOG.debug("existing links list is empty. Corrupt index");
}
if (neighbours.isEmpty()) {
LOG.debug("predicted links list is empty. Corrupt index");
}
setConnectionList(neigh, layer, neighbours);
}
}
wireConnectionForAllLayers(metadata.getEntryPoint().get(), item, curLevel, maxLevelCopy, true);
}
/**
* This method can be used to get the graph statistics, specifically
* it prints the histogram of inbound connections for each element.
*/
private String getStats() {
int histogramMaxBins = 50;
int[] histogram = new int[histogramMaxBins];
HashMap<T, Integer> mmap = new HashMap<T, Integer>();
for (HnswNode<T> key : graph.keySet()) {
if (key.level == 0) {
List<T> linkList = getConnectionListForRead(key.item, key.level);
for (T node : linkList) {
mmap.merge(node, 1, Integer::sum);
}
}
}
for (T key : mmap.keySet()) {
int ind = mmap.get(key) < histogramMaxBins - 1 ? mmap.get(key) : histogramMaxBins - 1;
histogram[ind]++;
}
int minNonZeroIndex;
for (minNonZeroIndex = histogramMaxBins - 1; minNonZeroIndex >= 0; minNonZeroIndex--) {
if (histogram[minNonZeroIndex] > 0) {
break;
}
}
final StringBuilder output = new StringBuilder();
for (int i = 0; i <= minNonZeroIndex; i++) {
output.append(i).append("\t").append(histogram[i] / (0.01f * mmap.keySet().size())).append("\n");
}
return output.toString();
}
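// Note on getRandomLevel below: levels follow an exponential decay,
// level = floor(-ln(U) * levelMultiplier) with U uniform in (0, 1) and
// levelMultiplier = 1 / ln(maxM), so an element lands on layer >= L with
// probability roughly maxM^(-L); higher layers are exponentially rarer.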
private int getRandomLevel() {
return (int) (-Math.log(randomProvider.get().nextDouble()) * levelMultiplier);
}
/**
* Note that to avoid deadlocks it is important that this method is called after all the searches
* of the graph have completed. If you take a lock on any items discovered in the graph after
* this, you may get stuck waiting on a thread that is waiting for item to be fully inserted.
* <p>
* Note: when using concurrent writers we can miss connections that we would otherwise get.
* This will reduce the recall.
* <p>
* For a full explanation of locking see this document: http://go/hnsw-locking
* The method returns the closest nearest neighbor (can be used as an enter point)
*/
private T mutuallyConnectNewElement(
final T item,
final DistancedItemQueue<T, T> candidates, // Max queue
final int level,
final boolean isUpdate
) {
// Using maxM here. Its use is ambiguous in the HNSW paper,
// so we follow the way it is used in the hnswlib implementation.
final ImmutableList<T> neighbours = selectNearestNeighboursByHeuristic(candidates, maxM);
setConnectionList(item, level, neighbours);
final int M = level == 0 ? maxM0 : maxM;
for (T nn : neighbours) {
if (nn.equals(item)) {
continue;
}
final Lock curLock = locks.computeIfAbsent(nn, lockProvider).writeLock();
curLock.lock();
try {
final HnswNode<T> key = HnswNode.from(level, nn);
final ImmutableList<T> connections = graph.getOrDefault(key, ImmutableList.of());
final boolean isItemAlreadyPresent = isUpdate && connections.contains(item);
// If `item` is already present in the neighboring connections,
// then no need to modify any connections or run the search heuristics.
if (isItemAlreadyPresent) {
continue;
}
final ImmutableList<T> updatedConnections;
if (connections.size() < M) {
final List<T> temp = new ArrayList<>(connections);
temp.add(item);
updatedConnections = ImmutableList.copyOf(temp.iterator());
} else {
// Max Queue
final DistancedItemQueue<T, T> queue = new DistancedItemQueue<>(
nn,
connections,
false,
distFnIndex
);
queue.enqueue(item);
updatedConnections = selectNearestNeighboursByHeuristic(queue, M);
}
if (updatedConnections.isEmpty()) {
LOG.debug("Internal error: predicted links list is empty");
}
graph.put(key, updatedConnections);
} finally {
curLock.unlock();
}
}
return neighbours.get(0);
}
/*
* bestEntryPointUntilLayer starts the graph search for item from the entry point
* until the search reaches the selectedLayer layer.
* @return a point from the selectedLayer layer that was the closest on the (selectedLayer+1) layer
*/
private <K> T bestEntryPointUntilLayer(
final T entryPoint,
final K item,
int maxLayer,
int selectedLayer,
DistanceFunction<K, T> distFn
) {
T curObj = entryPoint;
if (selectedLayer < maxLayer) {
float curDist = distFn.distance(item, curObj);
for (int level = maxLayer; level > selectedLayer; level--) {
boolean changed = true;
while (changed) {
changed = false;
final List<T> list = getConnectionListForRead(curObj, level);
for (T nn : list) {
final float tempDist = distFn.distance(item, nn);
if (tempDist < curDist) {
curDist = tempDist;
curObj = nn;
changed = true;
}
}
}
}
}
return curObj;
}
@VisibleForTesting
protected ImmutableList<T> selectNearestNeighboursByHeuristic(
final DistancedItemQueue<T, T> candidates, // Max queue
final int maxConnections
) {
Preconditions.checkState(!candidates.isMinQueue(),
"candidates in selectNearestNeighboursByHeuristic should be a max queue");
final T baseElement = candidates.getOrigin();
if (candidates.size() <= maxConnections) {
List<T> list = candidates.toListWithItem();
list.remove(baseElement);
return ImmutableList.copyOf(list);
} else {
final List<T> resSet = new ArrayList<>(maxConnections);
// Min queue for closest elements first
final DistancedItemQueue<T, T> minQueue = candidates.reverse();
while (minQueue.nonEmpty()) {
if (resSet.size() >= maxConnections) {
break;
}
final DistancedItem<T> candidate = minQueue.dequeue();
// We do not want to create self-loops:
// the heuristic is used only for creating the links
if (candidate.getItem().equals(baseElement)) {
continue;
}
boolean toInclude = true;
for (T e : resSet) {
// Do not include the candidate if its distance to any existing item in
// resSet is smaller than its distance to the base item. By doing this, the
// connections of the graph will be more diverse, and in case of a highly clustered data set,
// connections will be made between clusters instead of all being in the same cluster.
final float dist = distFnIndex.distance(e, candidate.getItem());
if (dist < candidate.getDistance()) {
toInclude = false;
break;
}
}
if (toInclude) {
resSet.add(candidate.getItem());
}
}
return ImmutableList.copyOf(resSet);
}
}
/**
* Search the index for the neighbours.
*
* @param query Query
* @param numOfNeighbours Number of neighbours to search for.
* @param ef This param controls the accuracy of the search.
*           The bigger the ef, the better the accuracy, at the expense of latency.
*           Keep it at least the number of neighbours to find.
* @return Neighbours
*/
public List<DistancedItem<T>> searchKnn(final Q query, final int numOfNeighbours, final int ef) {
final HnswMeta<T> metadata = graphMeta.get();
if (metadata.getEntryPoint().isPresent()) {
T entryPoint = bestEntryPointUntilLayer(metadata.getEntryPoint().get(),
query, metadata.getMaxLevel(), 0, distFnQuery);
// Get the actual neighbours from 0th layer
final List<DistancedItem<T>> neighbours =
searchLayerForCandidates(query, entryPoint, Math.max(ef, numOfNeighbours),
0, distFnQuery, false).dequeueAll();
Collections.reverse(neighbours);
return neighbours.size() > numOfNeighbours
? neighbours.subList(0, numOfNeighbours) : neighbours;
} else {
return Collections.emptyList();
}
}
// This method is currently not used;
// it is kept for debugging purposes only
private void checkIntegrity(String message) {
final HnswMeta<T> metadata = graphMeta.get();
for (HnswNode<T> node : graph.keySet()) {
List<T> linkList = graph.get(node);
for (T el : linkList) {
if (el.equals(node.item)) {
LOG.debug(message);
throw new RuntimeException("integrity check failed");
}
}
}
}
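// searchLayerForCandidates below is the standard HNSW greedy best-first scan of a
// single layer: cQueue (min queue) is the expansion frontier ordered by distance
// to `item`, wQueue (max queue) keeps the best `ef` results found so far, and the
// scan stops once the closest unexpanded candidate is farther away than the worst
// result currently kept in wQueue.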
private <K> DistancedItemQueue<K, T> searchLayerForCandidates(
final K item,
final T entryPoint,
final int ef,
final int level,
final DistanceFunction<K, T> distFn,
boolean isUpdate
) {
// Min queue
final DistancedItemQueue<K, T> cQueue = new DistancedItemQueue<>(
item,
Collections.singletonList(entryPoint),
true,
distFn
);
// Max Queue
final DistancedItemQueue<K, T> wQueue = cQueue.reverse();
final Set<T> visited = new HashSet<>();
float lowerBoundDistance = wQueue.peek().getDistance();
visited.add(entryPoint);
while (cQueue.nonEmpty()) {
final DistancedItem<T> candidate = cQueue.peek();
if (candidate.getDistance() > lowerBoundDistance) {
break;
}
cQueue.dequeue();
final List<T> list = getConnectionListForRead(candidate.getItem(), level);
for (T nn : list) {
if (!visited.contains(nn)) {
visited.add(nn);
final float distance = distFn.distance(item, nn);
if (wQueue.size() < ef || distance < wQueue.peek().getDistance()) {
cQueue.enqueue(nn, distance);
if (isUpdate && item.equals(nn)) {
continue;
}
wQueue.enqueue(nn, distance);
if (wQueue.size() > ef) {
wQueue.dequeue();
}
lowerBoundDistance = wQueue.peek().getDistance();
}
}
}
}
return wQueue;
}
/**
* Serialize hnsw index
*/
public void toDirectory(IndexOutputFile indexOutputFile, Injection<T, byte[]> injection)
throws IOException, TException {
final int totalGraphEntries = HnswIndexIOUtil.saveHnswGraphEntries(
graph,
indexOutputFile.createFile(GRAPH_FILE_NAME).getOutputStream(),
injection);
HnswIndexIOUtil.saveMetadata(
graphMeta.get(),
efConstruction,
maxM,
totalGraphEntries,
injection,
indexOutputFile.createFile(METADATA_FILE_NAME).getOutputStream());
}
/**
* Load hnsw index
*/
public static <T, Q> HnswIndex<T, Q> loadHnswIndex(
DistanceFunction<T, T> distFnIndex,
DistanceFunction<Q, T> distFnQuery,
AbstractFile directory,
Injection<T, byte[]> injection,
RandomProvider randomProvider) throws IOException, TException {
final AbstractFile graphFile = directory.getChild(GRAPH_FILE_NAME);
final AbstractFile metadataFile = directory.getChild(METADATA_FILE_NAME);
final HnswInternalIndexMetadata metadata = HnswIndexIOUtil.loadMetadata(metadataFile);
final Map<HnswNode<T>, ImmutableList<T>> graph =
HnswIndexIOUtil.loadHnswGraph(graphFile, injection, metadata.numElements);
final ByteBuffer entryPointBB = metadata.entryPoint;
final HnswMeta<T> graphMeta = new HnswMeta<>(
metadata.maxLevel,
entryPointBB == null ? Optional.empty()
: Optional.of(injection.invert(ArrayByteBufferCodec.decode(entryPointBB)).get())
);
return new HnswIndex<>(
distFnIndex,
distFnQuery,
metadata.efConstruction,
metadata.maxM,
metadata.numElements,
graphMeta,
graph,
randomProvider
);
}
private List<T> getConnectionListForRead(T node, int level) {
final Lock curLock = locks.computeIfAbsent(node, lockProvider).readLock();
curLock.lock();
final List<T> list;
try {
list = graph
.getOrDefault(HnswNode.from(level, node), ImmutableList.of());
} finally {
curLock.unlock();
}
return list;
}
@VisibleForTesting
AtomicReference<HnswMeta<T>> getGraphMeta() {
return graphMeta;
}
@VisibleForTesting
Map<T, ReadWriteLock> getLocks() {
return locks;
}
@VisibleForTesting
Map<HnswNode<T>, ImmutableList<T>> getGraph() {
return graph;
}
public interface RandomProvider {
/**
* RandomProvider interface made public for scala 2.12 compat
*/
Random get();
}
}
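
A minimal end-to-end sketch of building and querying the index. The constructor is protected, so this assumes package-local access; it reuses the illustrative SquaredL2Distance from above and uses raw float[] items purely as a toy (they rely on reference identity in the lock map, so a production caller would wrap embeddings in a type with value-based equals/hashCode).

package com.twitter.ann.hnsw;

import java.util.List;
import java.util.Random;

public final class HnswIndexExample {
  public static void main(String[] args) throws IllegalDuplicateInsertException {
    final DistanceFunction<float[], float[]> dist = new SquaredL2Distance();
    final HnswIndex<float[], float[]> index = new HnswIndex<>(
        dist,         // item-to-item distance used while indexing
        dist,         // query-to-item distance used while searching
        200,          // efConstruction
        16,           // maxM
        1_000,        // expectedElements
        Random::new); // RandomProvider
    index.insert(new float[] {0f, 0f});
    index.insert(new float[] {1f, 1f});
    index.insert(new float[] {5f, 5f});
    // Ask for the 2 nearest neighbours with ef = 50.
    final List<DistancedItem<float[]>> nn = index.searchKnn(new float[] {0.9f, 1.1f}, 2, 50);
    System.out.println(nn.get(0).getDistance()); // distance to (1, 1)
  }
}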

View File

@ -1,133 +0,0 @@
package com.twitter.ann.hnsw;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.ByteBuffer;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import com.google.common.collect.ImmutableList;
import org.apache.thrift.TDeserializer;
import org.apache.thrift.TException;
import org.apache.thrift.TSerializer;
import org.apache.thrift.protocol.TBinaryProtocol;
import org.apache.thrift.protocol.TProtocol;
import org.apache.thrift.transport.TIOStreamTransport;
import org.apache.thrift.transport.TTransportException;
import com.twitter.ann.common.thriftjava.HnswGraphEntry;
import com.twitter.ann.common.thriftjava.HnswInternalIndexMetadata;
import com.twitter.bijection.Injection;
import com.twitter.mediaservices.commons.codec.ArrayByteBufferCodec;
import com.twitter.search.common.file.AbstractFile;
public final class HnswIndexIOUtil {
private HnswIndexIOUtil() {
}
/**
* Save thrift object in file
*/
public static <T> void saveMetadata(
HnswMeta<T> graphMeta,
int efConstruction,
int maxM,
int numElements,
Injection<T, byte[]> injection,
OutputStream outputStream
) throws IOException, TException {
final int maxLevel = graphMeta.getMaxLevel();
final HnswInternalIndexMetadata metadata = new HnswInternalIndexMetadata(
maxLevel,
efConstruction,
maxM,
numElements
);
if (graphMeta.getEntryPoint().isPresent()) {
metadata.setEntryPoint(injection.apply(graphMeta.getEntryPoint().get()));
}
final TSerializer serializer = new TSerializer(new TBinaryProtocol.Factory());
outputStream.write(serializer.serialize(metadata));
outputStream.close();
}
/**
* Load Hnsw index metadata
*/
public static HnswInternalIndexMetadata loadMetadata(AbstractFile file)
throws IOException, TException {
final HnswInternalIndexMetadata obj = new HnswInternalIndexMetadata();
final TDeserializer deserializer = new TDeserializer(new TBinaryProtocol.Factory());
deserializer.deserialize(obj, file.getByteSource().read());
return obj;
}
/**
* Load Hnsw graph entries from file
*/
public static <T> Map<HnswNode<T>, ImmutableList<T>> loadHnswGraph(
AbstractFile file,
Injection<T, byte[]> injection,
int numElements
) throws IOException, TException {
final InputStream stream = file.getByteSource().openBufferedStream();
final TProtocol protocol = new TBinaryProtocol(new TIOStreamTransport(stream));
final Map<HnswNode<T>, ImmutableList<T>> graph =
new HashMap<>(numElements);
while (true) {
try {
final HnswGraphEntry entry = new HnswGraphEntry();
entry.read(protocol);
final HnswNode<T> node = HnswNode.from(entry.level,
injection.invert(ArrayByteBufferCodec.decode(entry.key)).get());
final List<T> list = entry.getNeighbours().stream()
.map(bb -> injection.invert(ArrayByteBufferCodec.decode(bb)).get())
.collect(Collectors.toList());
graph.put(node, ImmutableList.copyOf(list.iterator()));
} catch (TException e) {
if (e instanceof TTransportException
&& TTransportException.class.cast(e).getType() == TTransportException.END_OF_FILE) {
stream.close();
break;
}
stream.close();
throw e;
}
}
return graph;
}
/**
* Save hnsw graph in file
*
* @return number of keys in the graph
*/
public static <T> int saveHnswGraphEntries(
Map<HnswNode<T>, ImmutableList<T>> graph,
OutputStream outputStream,
Injection<T, byte[]> injection
) throws IOException, TException {
final TProtocol protocol = new TBinaryProtocol(new TIOStreamTransport(outputStream));
final Set<HnswNode<T>> nodes = graph.keySet();
for (HnswNode<T> node : nodes) {
final HnswGraphEntry entry = new HnswGraphEntry();
entry.setLevel(node.level);
entry.setKey(injection.apply(node.item));
final List<ByteBuffer> nn = graph.getOrDefault(node, ImmutableList.of()).stream()
.map(t -> ByteBuffer.wrap(injection.apply(t)))
.collect(Collectors.toList());
entry.setNeighbours(nn);
entry.write(protocol);
}
outputStream.close();
return nodes.size();
}
}
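
A round-trip sketch of the record-stream framing used above: HnswGraphEntry structs are written back to back with no explicit count, and the reader loops until the transport reports END_OF_FILE. The in-memory streams and one-byte keys are our own stand-ins for the real files and injected ids.

package com.twitter.ann.hnsw;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import org.apache.thrift.protocol.TBinaryProtocol;
import org.apache.thrift.transport.TIOStreamTransport;
import org.apache.thrift.transport.TTransportException;
import com.twitter.ann.common.thriftjava.HnswGraphEntry;

public final class GraphStreamRoundTrip {
  public static void main(String[] args) throws Exception {
    final ByteArrayOutputStream out = new ByteArrayOutputStream();
    final TBinaryProtocol writer = new TBinaryProtocol(new TIOStreamTransport(out));
    for (int level = 0; level < 2; level++) {
      final HnswGraphEntry entry = new HnswGraphEntry();
      entry.setLevel(level);
      entry.setKey(new byte[] {(byte) level}); // stand-in for injection.apply(item)
      entry.write(writer);
    }
    final TBinaryProtocol reader = new TBinaryProtocol(
        new TIOStreamTransport(new ByteArrayInputStream(out.toByteArray())));
    int count = 0;
    while (true) {
      final HnswGraphEntry entry = new HnswGraphEntry();
      try {
        entry.read(reader);
      } catch (TTransportException e) {
        if (e.getType() == TTransportException.END_OF_FILE) {
          break; // clean end of the record stream
        }
        throw e;
      }
      count++;
    }
    System.out.println(count); // 2
  }
}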

Binary file not shown.

View File

@ -1,45 +0,0 @@
package com.twitter.ann.hnsw;
import java.util.Objects;
import java.util.Optional;
class HnswMeta<T> {
private final int maxLevel;
private final Optional<T> entryPoint;
HnswMeta(int maxLevel, Optional<T> entryPoint) {
this.maxLevel = maxLevel;
this.entryPoint = entryPoint;
}
public int getMaxLevel() {
return maxLevel;
}
public Optional<T> getEntryPoint() {
return entryPoint;
}
@Override
public boolean equals(Object o) {
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
HnswMeta<?> hnswMeta = (HnswMeta<?>) o;
return maxLevel == hnswMeta.maxLevel
&& Objects.equals(entryPoint, hnswMeta.entryPoint);
}
@Override
public int hashCode() {
return Objects.hash(maxLevel, entryPoint);
}
@Override
public String toString() {
return "HnswMeta{maxLevel=" + maxLevel + ", entryPoint=" + entryPoint + '}';
}
}

Binary file not shown.

View File

@ -1,45 +0,0 @@
package com.twitter.ann.hnsw;
import org.apache.commons.lang.builder.EqualsBuilder;
import org.apache.commons.lang.builder.HashCodeBuilder;
public class HnswNode<T> {
public final int level;
public final T item;
public HnswNode(int level, T item) {
this.level = level;
this.item = item;
}
/**
* Create a hnsw node.
*/
public static <T> HnswNode<T> from(int level, T item) {
return new HnswNode<>(level, item);
}
@Override
public boolean equals(Object o) {
if (o == this) {
return true;
}
if (!(o instanceof HnswNode)) {
return false;
}
HnswNode<?> that = (HnswNode<?>) o;
return new EqualsBuilder()
.append(this.item, that.item)
.append(this.level, that.level)
.isEquals();
}
@Override
public int hashCode() {
return new HashCodeBuilder()
.append(item)
.append(level)
.toHashCode();
}
}

View File

@ -1,7 +0,0 @@
package com.twitter.ann.hnsw;
public class IllegalDuplicateInsertException extends Exception {
public IllegalDuplicateInsertException(String message) {
super(message);
}
}

View File

@ -1,38 +0,0 @@
resources(
name = "sql",
sources = ["bq.sql"],
)
python3_library(
name = "faiss_indexing",
sources = ["**/*.py"],
tags = ["bazel-compatible"],
dependencies = [
":sql",
"3rdparty/python/apache-beam:default",
"3rdparty/python/faiss-gpu:default",
"3rdparty/python/gcsfs:default",
"3rdparty/python/google-cloud-bigquery:default",
"3rdparty/python/google-cloud-storage",
"3rdparty/python/numpy:default",
"3rdparty/python/pandas:default",
"3rdparty/python/pandas-gbq:default",
"3rdparty/python/pyarrow:default",
"src/python/twitter/ml/common/apache_beam",
],
)
python37_binary(
name = "faiss_indexing_bin",
sources = ["faiss_index_bq_dataset.py"],
platforms = [
"current",
"linux_x86_64",
],
tags = ["no-mypy"],
zip_safe = False,
dependencies = [
":faiss_indexing",
"3rdparty/python/_closures/ann/src/main/python/dataflow:faiss_indexing_bin",
],
)

Binary file not shown.

Binary file not shown.

View File

@ -1,6 +0,0 @@
WITH maxts as (SELECT as value MAX(ts) as ts FROM `twttr-recos-ml-prod.ssedhain.twhin_tweet_avg_embedding`)
SELECT entityId, embedding
FROM `twttr-recos-ml-prod.ssedhain.twhin_tweet_avg_embedding`
WHERE ts >= (select max(maxts) from maxts)
AND DATE(TIMESTAMP_MILLIS(createdAt)) <= (select max(maxts) from maxts)
AND DATE(TIMESTAMP_MILLIS(createdAt)) >= DATE_SUB((select max(maxts) from maxts), INTERVAL 1 DAY)

View File

@ -1,232 +0,0 @@
import argparse
import logging
import os
import pkgutil
import sys
from urllib.parse import urlsplit
import apache_beam as beam
from apache_beam.options.pipeline_options import PipelineOptions
import faiss
def parse_d6w_config(argv=None):
"""Parse d6w config.
:param argv: d6w config
:return: dictionary containing d6w config
"""
parser = argparse.ArgumentParser(
description="See https://docbird.twitter.biz/d6w/model.html for any parameters inherited from d6w job config"
)
parser.add_argument("--job_name", dest="job_name", required=True, help="d6w attribute")
parser.add_argument("--project", dest="project", required=True, help="d6w attribute")
parser.add_argument(
"--staging_location", dest="staging_location", required=True, help="d6w attribute"
)
parser.add_argument("--temp_location", dest="temp_location", required=True, help="d6w attribute")
parser.add_argument(
"--output_location",
dest="output_location",
required=True,
help="GCS bucket and path where resulting artifacts are uploaded",
)
parser.add_argument(
"--service_account_email", dest="service_account_email", required=True, help="d6w attribute"
)
parser.add_argument(
"--factory_string",
dest="factory_string",
required=False,
help="FAISS factory string describing index to build. See https://github.com/facebookresearch/faiss/wiki/The-index-factory",
)
parser.add_argument(
"--metric",
dest="metric",
required=True,
help="Metric used to compute distance between embeddings. Valid values are 'l2', 'ip', 'l1', 'linf'",
)
parser.add_argument(
"--use_gpu",
dest="gpu",
required=True,
help="--use_gpu=yes if you want to use GPU during index building",
)
known_args, unknown_args = parser.parse_known_args(argv)
d6w_config = vars(known_args)
d6w_config["gpu"] = d6w_config["gpu"].lower() == "yes"
d6w_config["metric"] = parse_metric(d6w_config)
"""
WARNING: Currently, d6w (a Twitter tool used to deploy Dataflow jobs to GCP) and
PipelineOptions.for_dataflow_runner (a helper method in twitter.ml.common.apache_beam) do not
play nicely together. The helper method will overwrite some of the config specified in the d6w
file using the defaults in https://sourcegraph.twitter.biz/git.twitter.biz/source/-/blob/src/python/twitter/ml/common/apache_beam/__init__.py?L24.
However, the d6w output message will still report that the config specified in the d6w file was used.
"""
logging.warning(
f"The following d6w config parameters will be overwritten by the defaults in "
f"https://sourcegraph.twitter.biz/git.twitter.biz/source/-/blob/src/python/twitter/ml/common/apache_beam/__init__.py?L24\n"
f"{str(unknown_args)}"
)
return d6w_config
def get_bq_query():
"""
Query is expected to return rows with unique entityId
"""
return pkgutil.get_data(__name__, "bq.sql").decode("utf-8")
def parse_metric(config):
metric_str = config["metric"].lower()
if metric_str == "l2":
return faiss.METRIC_L2
elif metric_str == "ip":
return faiss.METRIC_INNER_PRODUCT
elif metric_str == "l1":
return faiss.METRIC_L1
elif metric_str == "linf":
return faiss.METRIC_Linf
else:
raise Exception(f"Unknown metric: {metric_str}")
def run_pipeline(argv=[]):
config = parse_d6w_config(argv)
argv_with_extras = argv
if config["gpu"]:
argv_with_extras.extend(["--experiments", "use_runner_v2"])
argv_with_extras.extend(
["--experiments", "worker_accelerator=type:nvidia-tesla-t4;count:1;install-nvidia-driver"]
)
argv_with_extras.extend(
[
"--worker_harness_container_image",
"gcr.io/twttr-recos-ml-prod/dataflow-gpu/beam2_39_0_py3_7",
]
)
options = PipelineOptions(argv_with_extras)
output_bucket_name = urlsplit(config["output_location"]).netloc
with beam.Pipeline(options=options) as p:
input_data = p | "Read from BigQuery" >> beam.io.ReadFromBigQuery(
method=beam.io.ReadFromBigQuery.Method.DIRECT_READ,
query=get_bq_query(),
use_standard_sql=True,
)
index_built = input_data | "Build and upload index" >> beam.CombineGlobally(
MergeAndBuildIndex(
output_bucket_name,
config["output_location"],
config["factory_string"],
config["metric"],
config["gpu"],
)
)
# Make linter happy
index_built
class MergeAndBuildIndex(beam.CombineFn):
def __init__(self, bucket_name, gcs_output_path, factory_string, metric, gpu):
self.bucket_name = bucket_name
self.gcs_output_path = gcs_output_path
self.factory_string = factory_string
self.metric = metric
self.gpu = gpu
def create_accumulator(self):
return []
def add_input(self, accumulator, element):
accumulator.append(element)
return accumulator
def merge_accumulators(self, accumulators):
merged = []
for accum in accumulators:
merged.extend(accum)
return merged
def extract_output(self, rows):
# Reimports are needed on workers
import glob
import subprocess
import faiss
from google.cloud import storage
import numpy as np
client = storage.Client()
bucket = client.get_bucket(self.bucket_name)
logging.info("Building FAISS index")
logging.info(f"There are {len(rows)} rows")
ids = np.array([x["entityId"] for x in rows]).astype("long")
embeds = np.array([x["embedding"] for x in rows]).astype("float32")
dimensions = len(embeds[0])
N = ids.shape[0]
logging.info(f"There are {dimensions} dimensions")
if self.factory_string is None:
M = 48
divideable_dimensions = (dimensions // M) * M
if divideable_dimensions != dimensions:
opq_prefix = f"OPQ{M}_{divideable_dimensions}"
else:
opq_prefix = f"OPQ{M}"
clusters = N // 20
self.factory_string = f"{opq_prefix},IVF{clusters},PQ{M}"
logging.info(f"Factory string is {self.factory_string}, metric={self.metric}")
if self.gpu:
logging.info("Using GPU")
res = faiss.StandardGpuResources()
cpu_index = faiss.index_factory(dimensions, self.factory_string, self.metric)
cpu_index = faiss.IndexIDMap(cpu_index)
gpu_index = faiss.index_cpu_to_gpu(res, 0, cpu_index)
gpu_index.train(embeds)
gpu_index.add_with_ids(embeds, ids)
cpu_index = faiss.index_gpu_to_cpu(gpu_index)
else:
logging.info("Using CPU")
cpu_index = faiss.index_factory(dimensions, self.factory_string, self.metric)
cpu_index = faiss.IndexIDMap(cpu_index)
cpu_index.train(embeds)
cpu_index.add_with_ids(embeds, ids)
logging.info("Built faiss index")
local_path = "/indices"
logging.info(f"Writing indices to local {local_path}")
subprocess.run(f"mkdir -p {local_path}".strip().split())
local_index_path = os.path.join(local_path, "result.index")
faiss.write_index(cpu_index, local_index_path)
logging.info(f"Done writing indices to local {local_path}")
logging.info(f"Uploading to GCS with path {self.gcs_output_path}")
assert os.path.isdir(local_path)
for local_file in glob.glob(local_path + "/*"):
remote_path = os.path.join(
self.gcs_output_path.split("/")[-1], local_file[1 + len(local_path) :]
)
blob = bucket.blob(remote_path)
blob.upload_from_filename(local_file)
if __name__ == "__main__":
logging.getLogger().setLevel(logging.INFO)
run_pipeline(sys.argv)
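
When --factory_string is not supplied, the job above derives one from the data shape (OPQ48, optionally with a dimension-reduction target divisible by 48, an IVF with roughly one cluster per 20 vectors, then PQ48). A standalone sketch of that arithmetic, with names of our own choosing:

public final class FactoryStringDefault {
  // Mirrors the default computed in extract_output above when factory_string is None.
  static String defaultFactoryString(int dimensions, long numVectors) {
    final int m = 48;
    final int divisibleDimensions = (dimensions / m) * m;
    final String opqPrefix = divisibleDimensions != dimensions
        ? "OPQ" + m + "_" + divisibleDimensions
        : "OPQ" + m;
    final long clusters = numVectors / 20;
    return opqPrefix + ",IVF" + clusters + ",PQ" + m;
  }

  public static void main(String[] args) {
    // 300-dim embeddings over 1M vectors -> "OPQ48_288,IVF50000,PQ48"
    System.out.println(defaultFactoryString(300, 1_000_000L));
  }
}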

View File

@ -1,34 +0,0 @@
FROM --platform=linux/amd64 nvidia/cuda:11.2.2-cudnn8-runtime-ubuntu20.04
RUN \
# Add Deadsnakes repository that has a variety of Python packages for Ubuntu.
# See: https://launchpad.net/~deadsnakes/+archive/ubuntu/ppa
apt-key adv --keyserver keyserver.ubuntu.com --recv-keys F23C5A6CF475977595C89F51BA6932366A755776 \
&& echo "deb http://ppa.launchpad.net/deadsnakes/ppa/ubuntu focal main" >> /etc/apt/sources.list.d/custom.list \
&& echo "deb-src http://ppa.launchpad.net/deadsnakes/ppa/ubuntu focal main" >> /etc/apt/sources.list.d/custom.list \
&& apt-get update \
&& apt-get install -y curl \
python3.7 \
# With the python3.7 package, distutils needs to be installed separately.
python3.7-distutils \
python3-dev \
python3.7-dev \
libpython3.7-dev \
python3-apt \
gcc \
g++ \
&& rm -rf /var/lib/apt/lists/*
RUN update-alternatives --install /usr/bin/python python /usr/bin/python3.7 10
RUN rm -f /usr/bin/python3 && ln -s /usr/bin/python3.7 /usr/bin/python3
RUN \
curl https://bootstrap.pypa.io/get-pip.py | python \
&& pip3 install pip==22.0.3 \
&& python3 -m pip install --no-cache-dir apache-beam[gcp]==2.39.0
# Verify that there are no conflicting dependencies.
RUN pip3 check
# Copy the Apache Beam worker dependencies from the Beam Python 3.7 SDK image.
COPY --from=apache/beam_python3.7_sdk:2.39.0 /opt/apache/beam /opt/apache/beam
# Set the entrypoint to Apache Beam SDK worker launcher.
ENTRYPOINT [ "/opt/apache/beam/boot" ]

View File

@ -1,6 +0,0 @@
steps:
- name: 'gcr.io/cloud-builders/docker'
args: ['build', '-t', 'gcr.io/twttr-recos-ml-prod/dataflow-gpu/beam2_39_0_py3_7', '.']
- name: 'gcr.io/cloud-builders/docker'
args: ['push', 'gcr.io/twttr-recos-ml-prod/dataflow-gpu/beam2_39_0_py3_7']
images: ['gcr.io/twttr-recos-ml-prod/dataflow-gpu/beam2_39_0_py3_7']

View File

@ -1,44 +0,0 @@
package com.twitter.ann.annoy
import com.twitter.ann.common.RuntimeParams
import com.twitter.ann.common.thriftscala.AnnoyIndexMetadata
import com.twitter.bijection.Injection
import com.twitter.mediaservices.commons.codec.ThriftByteBufferCodec
import com.twitter.ann.common.thriftscala.{AnnoyRuntimeParam, RuntimeParams => ServiceRuntimeParams}
import scala.util.{Failure, Success, Try}
object AnnoyCommon {
private[annoy] lazy val MetadataCodec = new ThriftByteBufferCodec(AnnoyIndexMetadata)
private[annoy] val IndexFileName = "annoy_index"
private[annoy] val MetaDataFileName = "annoy_index_metadata"
private[annoy] val IndexIdMappingFileName = "annoy_index_id_mapping"
val RuntimeParamsInjection: Injection[AnnoyRuntimeParams, ServiceRuntimeParams] =
new Injection[AnnoyRuntimeParams, ServiceRuntimeParams] {
override def apply(scalaParams: AnnoyRuntimeParams): ServiceRuntimeParams = {
ServiceRuntimeParams.AnnoyParam(
AnnoyRuntimeParam(
scalaParams.nodesToExplore
)
)
}
override def invert(thriftParams: ServiceRuntimeParams): Try[AnnoyRuntimeParams] =
thriftParams match {
case ServiceRuntimeParams.AnnoyParam(annoyParam) =>
Success(
AnnoyRuntimeParams(annoyParam.numOfNodesToExplore)
)
case p => Failure(new IllegalArgumentException(s"Expected AnnoyRuntimeParams got $p"))
}
}
}
case class AnnoyRuntimeParams(
/* Number of vectors to evaluate while searching. A larger value gives more accurate results, but takes a longer time to return.
* Default value would be numberOfTrees*numberOfNeighboursRequested
*/
nodesToExplore: Option[Int])
extends RuntimeParams {
override def toString: String = s"AnnoyRuntimeParams( nodesToExplore = $nodesToExplore)"
}

View File

@ -1,23 +0,0 @@
scala_library(
sources = ["*.scala"],
compiler_option_sets = ["fatal_warnings"],
platform = "java8",
tags = ["bazel-compatible"],
dependencies = [
"3rdparty/jvm/com/spotify:annoy-java",
"3rdparty/jvm/com/spotify:annoy-snapshot",
"3rdparty/jvm/com/twitter/storehaus:core",
"ann/src/main/scala/com/twitter/ann/common",
"ann/src/main/scala/com/twitter/ann/file_store",
"ann/src/main/thrift/com/twitter/ann/common:ann-common-scala",
"mediaservices/commons",
"src/java/com/twitter/search/common/file",
"src/scala/com/twitter/ml/api/embedding",
],
exports = [
"ann/src/main/scala/com/twitter/ann/common",
"src/java/com/twitter/common_internal/hadoop",
"src/java/com/twitter/search/common/file",
"src/scala/com/twitter/ml/api/embedding",
],
)

Binary file not shown.

View File

@ -1,123 +0,0 @@
package com.twitter.ann.annoy
import com.spotify.annoy.jni.base.{Annoy => AnnoyLib}
import com.twitter.ann.annoy.AnnoyCommon.IndexFileName
import com.twitter.ann.annoy.AnnoyCommon.MetaDataFileName
import com.twitter.ann.annoy.AnnoyCommon.MetadataCodec
import com.twitter.ann.common.EmbeddingType._
import com.twitter.ann.common._
import com.twitter.ann.common.thriftscala.AnnoyIndexMetadata
import com.twitter.concurrent.AsyncSemaphore
import com.twitter.mediaservices.commons.codec.ArrayByteBufferCodec
import com.twitter.search.common.file.AbstractFile
import com.twitter.search.common.file.LocalFile
import com.twitter.util.Future
import com.twitter.util.FuturePool
import java.io.File
import java.nio.file.Files
import org.apache.beam.sdk.io.fs.ResourceId
import scala.collection.JavaConverters._
private[annoy] object RawAnnoyIndexBuilder {
private[annoy] def apply[D <: Distance[D]](
dimension: Int,
numOfTrees: Int,
metric: Metric[D],
futurePool: FuturePool
): RawAppendable[AnnoyRuntimeParams, D] with Serialization = {
val indexBuilder = AnnoyLib.newIndex(dimension, annoyMetric(metric))
new RawAnnoyIndexBuilder(dimension, numOfTrees, metric, indexBuilder, futurePool)
}
private[this] def annoyMetric(metric: Metric[_]): AnnoyLib.Metric = {
metric match {
case L2 => AnnoyLib.Metric.EUCLIDEAN
case Cosine => AnnoyLib.Metric.ANGULAR
case _ => throw new RuntimeException("Not supported: " + metric)
}
}
}
private[this] class RawAnnoyIndexBuilder[D <: Distance[D]](
dimension: Int,
numOfTrees: Int,
metric: Metric[D],
indexBuilder: AnnoyLib.Builder,
futurePool: FuturePool)
extends RawAppendable[AnnoyRuntimeParams, D]
with Serialization {
private[this] var counter = 0
// Note: Only one thread can access the underlying index; multithreaded index building is not supported
private[this] val semaphore = new AsyncSemaphore(1)
override def append(embedding: EmbeddingVector): Future[Long] =
semaphore.acquireAndRun({
counter += 1
indexBuilder.addItem(
counter,
embedding.toArray
.map(float => float2Float(float))
.toList
.asJava
)
Future.value(counter)
})
override def toQueryable: Queryable[Long, AnnoyRuntimeParams, D] = {
val tempDirParent = Files.createTempDirectory("raw_annoy_index").toFile
tempDirParent.deleteOnExit
val tempDir = new LocalFile(tempDirParent)
this.toDirectory(tempDir)
RawAnnoyQueryIndex(
dimension,
metric,
futurePool,
tempDir
)
}
override def toDirectory(directory: ResourceId): Unit = {
toDirectory(new IndexOutputFile(directory))
}
/**
* Serialize the annoy index in a directory.
* @param directory: Directory to save to.
*/
override def toDirectory(directory: AbstractFile): Unit = {
toDirectory(new IndexOutputFile(directory))
}
private def toDirectory(directory: IndexOutputFile): Unit = {
val indexFile = directory.createFile(IndexFileName)
saveIndex(indexFile)
val metaDataFile = directory.createFile(MetaDataFileName)
saveMetadata(metaDataFile)
}
private[this] def saveIndex(indexFile: IndexOutputFile): Unit = {
val index = indexBuilder
.build(numOfTrees)
val temp = new LocalFile(File.createTempFile(IndexFileName, null))
index.save(temp.getPath)
indexFile.copyFrom(temp.getByteSource.openStream())
temp.delete()
}
private[this] def saveMetadata(metadataFile: IndexOutputFile): Unit = {
val numberOfVectorsIndexed = counter
val metadata = AnnoyIndexMetadata(
dimension,
Metric.toThrift(metric),
numOfTrees,
numberOfVectorsIndexed
)
val bytes = ArrayByteBufferCodec.decode(MetadataCodec.encode(metadata))
val temp = new LocalFile(File.createTempFile(MetaDataFileName, null))
temp.getByteSink.write(bytes)
metadataFile.copyFrom(temp.getByteSource.openStream())
temp.delete()
}
}

View File

@ -1,142 +0,0 @@
package com.twitter.ann.annoy
import com.spotify.annoy.{ANNIndex, IndexType}
import com.twitter.ann.annoy.AnnoyCommon._
import com.twitter.ann.common._
import com.twitter.ann.common.EmbeddingType._
import com.twitter.mediaservices.commons.codec.ArrayByteBufferCodec
import com.twitter.search.common.file.{AbstractFile, LocalFile}
import com.twitter.util.{Future, FuturePool}
import java.io.File
import scala.collection.JavaConverters._
private[annoy] object RawAnnoyQueryIndex {
private[annoy] def apply[D <: Distance[D]](
dimension: Int,
metric: Metric[D],
futurePool: FuturePool,
directory: AbstractFile
): Queryable[Long, AnnoyRuntimeParams, D] = {
val metadataFile = directory.getChild(MetaDataFileName)
val indexFile = directory.getChild(IndexFileName)
val metadata = MetadataCodec.decode(
ArrayByteBufferCodec.encode(metadataFile.getByteSource.read())
)
val existingDimension = metadata.dimension
assert(
existingDimension == dimension,
s"Dimensions do not match. requested: $dimension existing: $existingDimension"
)
val existingMetric = Metric.fromThrift(metadata.distanceMetric)
assert(
existingMetric == metric,
s"DistanceMetric do not match. requested: $metric existing: $existingMetric"
)
val index = loadIndex(indexFile, dimension, annoyMetric(metric))
new RawAnnoyQueryIndex[D](
dimension,
metric,
metadata.numOfTrees,
index,
futurePool
)
}
private[this] def annoyMetric(metric: Metric[_]): IndexType = {
metric match {
case L2 => IndexType.EUCLIDEAN
case Cosine => IndexType.ANGULAR
case _ => throw new RuntimeException("Not supported: " + metric)
}
}
private[this] def loadIndex(
indexFile: AbstractFile,
dimension: Int,
indexType: IndexType
): ANNIndex = {
var localIndexFile = indexFile
// If not a local file, copy it to local so that it can be memory mapped.
if (!indexFile.isInstanceOf[LocalFile]) {
val tempFile = File.createTempFile(IndexFileName, null)
tempFile.deleteOnExit()
val temp = new LocalFile(tempFile)
indexFile.copyTo(temp)
localIndexFile = temp
}
new ANNIndex(
dimension,
localIndexFile.getPath(),
indexType
)
}
}
private[this] class RawAnnoyQueryIndex[D <: Distance[D]](
dimension: Int,
metric: Metric[D],
numOfTrees: Int,
index: ANNIndex,
futurePool: FuturePool)
extends Queryable[Long, AnnoyRuntimeParams, D]
with AutoCloseable {
override def query(
embedding: EmbeddingVector,
numOfNeighbours: Int,
runtimeParams: AnnoyRuntimeParams
): Future[List[Long]] = {
queryWithDistance(embedding, numOfNeighbours, runtimeParams)
.map(_.map(_.neighbor))
}
override def queryWithDistance(
embedding: EmbeddingVector,
numOfNeighbours: Int,
runtimeParams: AnnoyRuntimeParams
): Future[List[NeighborWithDistance[Long, D]]] = {
futurePool {
val queryVector = embedding.toArray
val neigboursToRequest = neighboursToRequest(numOfNeighbours, runtimeParams)
val neigbours = index
.getNearestWithDistance(queryVector, neigboursToRequest)
.asScala
.take(numOfNeighbours)
.map { nn =>
val id = nn.getFirst.toLong
val distance = metric.fromAbsoluteDistance(nn.getSecond)
NeighborWithDistance(id, distance)
}
.toList
neigbours
}
}
// The Annoy Java lib does not expose a param for numOfNodesToExplore.
// The default number is numOfTrees * numOfNeighbours.
// A simple hack is to artificially increase the number of neighbours to request and then cap it before returning.
private[this] def neighboursToRequest(
numOfNeighbours: Int,
annoyParams: AnnoyRuntimeParams
): Int = {
annoyParams.nodesToExplore match {
case Some(nodesToExplore) => {
val numToRequest = nodesToExplore / numOfTrees
if (numToRequest < numOfNeighbours)
numOfNeighbours
else
numToRequest
}
case _ => numOfNeighbours
}
}
// Close the memory-map-based file resource.
override def close(): Unit = index.close()
}
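
The request-sizing workaround above reduces to a small piece of arithmetic; a standalone sketch (names ours, with a nullable Integer standing in for the Scala Option):

public final class AnnoyRequestSizing {
  // Annoy explores roughly numOfTrees * requested nodes, so to honor a
  // nodesToExplore override we over-request and cap the results afterwards.
  static int neighboursToRequest(int numOfNeighbours, int numOfTrees, Integer nodesToExplore) {
    if (nodesToExplore == null) {
      return numOfNeighbours; // no override: fall back to the library default
    }
    return Math.max(numOfNeighbours, nodesToExplore / numOfTrees);
  }

  public static void main(String[] args) {
    // 50 trees, 10 neighbours wanted, 2000 nodes to explore -> request 40
    System.out.println(neighboursToRequest(10, 50, 2000));
  }
}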

View File

@ -1,55 +0,0 @@
package com.twitter.ann.annoy
import com.twitter.ann.common._
import com.twitter.bijection.Injection
import com.twitter.search.common.file.AbstractFile
import com.twitter.util.FuturePool
// Class to provide Annoy based ann index.
object TypedAnnoyIndex {
/**
* Create Annoy based typed index builder that serializes index to a directory (HDFS/Local file system).
* It cannot be used in Scalding as it leverages C/C++ JNI bindings, whose build conflicts with the versions of some libs installed on Hadoop.
* You can use it on Aurora or with the IndexBuilding job, which triggers a Scalding job but then streams data to an Aurora machine for building the index.
* @param dimension dimension of embedding
* @param numOfTrees builds a forest of numOfTrees trees.
* More trees give higher precision when querying, at the cost of increased memory and disk storage requirements at build time.
* At runtime the index will be memory mapped, so memory won't be an issue, but disk storage will be needed.
* @param metric distance metric for nearest neighbour search
* @param injection Injection to convert bytes to Id.
* @tparam T Type of Id for embedding
* @tparam D Typed Distance
* @return Serializable AnnoyIndex
*/
def indexBuilder[T, D <: Distance[D]](
dimension: Int,
numOfTrees: Int,
metric: Metric[D],
injection: Injection[T, Array[Byte]],
futurePool: FuturePool
): Appendable[T, AnnoyRuntimeParams, D] with Serialization = {
TypedAnnoyIndexBuilderWithFile(dimension, numOfTrees, metric, injection, futurePool)
}
/**
* Load Annoy based queryable index from a directory
* @param dimension dimension of embedding
* @param metric distance metric for nearest neighbour search
* @param injection Injection to convert bytes to Id.
* @param futurePool FuturePool
* @param directory Directory (HDFS/Local file system) where serialized index is stored.
* @tparam T Type of Id for embedding
* @tparam D Typed Distance
* @return Typed Queryable AnnoyIndex
*/
def loadQueryableIndex[T, D <: Distance[D]](
dimension: Int,
metric: Metric[D],
injection: Injection[T, Array[Byte]],
futurePool: FuturePool,
directory: AbstractFile
): Queryable[T, AnnoyRuntimeParams, D] = {
TypedAnnoyQueryIndexWithFile(dimension, metric, injection, futurePool, directory)
}
}

View File

@ -1,55 +0,0 @@
package com.twitter.ann.annoy
import com.twitter.ann.annoy.AnnoyCommon.IndexIdMappingFileName
import com.twitter.ann.common._
import com.twitter.ann.file_store.WritableIndexIdFileStore
import com.twitter.bijection.Injection
import com.twitter.search.common.file.AbstractFile
import com.twitter.util.Future
import com.twitter.util.FuturePool
import org.apache.beam.sdk.io.fs.ResourceId
private[annoy] object TypedAnnoyIndexBuilderWithFile {
private[annoy] def apply[T, D <: Distance[D]](
dimension: Int,
numOfTrees: Int,
metric: Metric[D],
injection: Injection[T, Array[Byte]],
futurePool: FuturePool
): Appendable[T, AnnoyRuntimeParams, D] with Serialization = {
val index = RawAnnoyIndexBuilder(dimension, numOfTrees, metric, futurePool)
val writableFileStore = WritableIndexIdFileStore(injection)
new TypedAnnoyIndexBuilderWithFile[T, D](index, writableFileStore)
}
}
private[this] class TypedAnnoyIndexBuilderWithFile[T, D <: Distance[D]](
indexBuilder: RawAppendable[AnnoyRuntimeParams, D] with Serialization,
store: WritableIndexIdFileStore[T])
extends Appendable[T, AnnoyRuntimeParams, D]
with Serialization {
private[this] val transformedIndex = IndexTransformer.transformAppendable(indexBuilder, store)
override def append(entity: EntityEmbedding[T]): Future[Unit] = {
transformedIndex.append(entity)
}
override def toDirectory(directory: ResourceId): Unit = {
indexBuilder.toDirectory(directory)
toDirectory(new IndexOutputFile(directory))
}
override def toDirectory(directory: AbstractFile): Unit = {
indexBuilder.toDirectory(directory)
toDirectory(new IndexOutputFile(directory))
}
private def toDirectory(directory: IndexOutputFile): Unit = {
val indexIdFile = directory.createFile(IndexIdMappingFileName)
store.save(indexIdFile)
}
override def toQueryable: Queryable[T, AnnoyRuntimeParams, D] = {
transformedIndex.toQueryable
}
}

View File

@ -1,42 +0,0 @@
package com.twitter.ann.annoy
import com.twitter.ann.annoy.AnnoyCommon._
import com.twitter.ann.common._
import com.twitter.ann.file_store.ReadableIndexIdFileStore
import com.twitter.bijection.Injection
import com.twitter.search.common.file.AbstractFile
import com.twitter.util.FuturePool
private[annoy] object TypedAnnoyQueryIndexWithFile {
private[annoy] def apply[T, D <: Distance[D]](
dimension: Int,
metric: Metric[D],
injection: Injection[T, Array[Byte]],
futurePool: FuturePool,
directory: AbstractFile
): Queryable[T, AnnoyRuntimeParams, D] = {
val deserializer =
new TypedAnnoyQueryIndexWithFile(dimension, metric, futurePool, injection)
deserializer.fromDirectory(directory)
}
}
private[this] class TypedAnnoyQueryIndexWithFile[T, D <: Distance[D]](
dimension: Int,
metric: Metric[D],
futurePool: FuturePool,
injection: Injection[T, Array[Byte]])
extends QueryableDeserialization[
T,
AnnoyRuntimeParams,
D,
Queryable[T, AnnoyRuntimeParams, D]
] {
override def fromDirectory(directory: AbstractFile): Queryable[T, AnnoyRuntimeParams, D] = {
val index = RawAnnoyQueryIndex(dimension, metric, futurePool, directory)
val indexIdFile = directory.getChild(IndexIdMappingFileName)
val readableFileStore = ReadableIndexIdFileStore(indexIdFile, injection)
IndexTransformer.transformQueryable(index, readableFileStore)
}
}

View File

@@ -1,12 +0,0 @@
scala_library(
sources = ["*.scala"],
compiler_option_sets = ["fatal_warnings"],
platform = "java8",
tags = ["bazel-compatible"],
dependencies = [
"ann/src/main/scala/com/twitter/ann/common",
"ann/src/main/scala/com/twitter/ann/serialization",
"ann/src/main/thrift/com/twitter/ann/serialization:serialization-scala",
"src/java/com/twitter/search/common/file",
],
)

View File

@@ -1,64 +0,0 @@
package com.twitter.ann.brute_force
import com.google.common.annotations.VisibleForTesting
import com.twitter.ann.common.{Distance, EntityEmbedding, Metric, QueryableDeserialization}
import com.twitter.ann.serialization.{PersistedEmbeddingInjection, ThriftIteratorIO}
import com.twitter.ann.serialization.thriftscala.PersistedEmbedding
import com.twitter.search.common.file.{AbstractFile, LocalFile}
import com.twitter.util.FuturePool
import java.io.File
/**
* @param factory creates a BruteForceIndex from the arguments. This is only exposed for testing.
 *                If for some reason you pass this arg in, make sure that it eagerly consumes the
 *                iterator; otherwise you may close the input stream that the iterator is still
 *                reading from.
 * @tparam T the id type of the embeddings
*/
class BruteForceDeserialization[T, D <: Distance[D]] @VisibleForTesting private[brute_force] (
metric: Metric[D],
embeddingInjection: PersistedEmbeddingInjection[T],
futurePool: FuturePool,
thriftIteratorIO: ThriftIteratorIO[PersistedEmbedding],
factory: (Metric[D], FuturePool, Iterator[EntityEmbedding[T]]) => BruteForceIndex[T, D])
extends QueryableDeserialization[T, BruteForceRuntimeParams.type, D, BruteForceIndex[T, D]] {
import BruteForceIndex._
def this(
metric: Metric[D],
embeddingInjection: PersistedEmbeddingInjection[T],
futurePool: FuturePool,
thriftIteratorIO: ThriftIteratorIO[PersistedEmbedding]
) = {
this(
metric,
embeddingInjection,
futurePool,
thriftIteratorIO,
factory = BruteForceIndex.apply[T, D]
)
}
override def fromDirectory(
serializationDirectory: AbstractFile
): BruteForceIndex[T, D] = {
val file = File.createTempFile(DataFileName, "tmp")
file.deleteOnExit()
val temp = new LocalFile(file)
val dataFile = serializationDirectory.getChild(DataFileName)
dataFile.copyTo(temp)
val inputStream = temp.getByteSource.openBufferedStream()
try {
val iterator: Iterator[PersistedEmbedding] = thriftIteratorIO.fromInputStream(inputStream)
val embeddings = iterator.map { thriftEmbedding =>
embeddingInjection.invert(thriftEmbedding).get
}
factory(metric, futurePool, embeddings)
} finally {
inputStream.close()
temp.delete()
}
}
}
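
// A hedged usage sketch, not from the original file: wiring a
// BruteForceDeserialization and loading an index from a directory. The L2
// metric choice and the unbounded pool are illustrative assumptions.
object BruteForceDeserializationSketch {
  import com.twitter.ann.common.{L2, L2Distance}
  import com.twitter.util.FuturePool

  def load[T](
    embeddingInjection: PersistedEmbeddingInjection[T],
    thriftIteratorIO: ThriftIteratorIO[PersistedEmbedding],
    directory: AbstractFile
  ): BruteForceIndex[T, L2Distance] = {
    val deserialization = new BruteForceDeserialization[T, L2Distance](
      L2,
      embeddingInjection,
      FuturePool.unboundedPool,
      thriftIteratorIO
    )
    // fromDirectory copies the data file locally and eagerly materializes the
    // embeddings into the index before the temp stream is closed.
    deserialization.fromDirectory(directory)
  }
}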

View File

@@ -1,162 +0,0 @@
package com.twitter.ann.brute_force
import com.twitter.ann.common.Appendable
import com.twitter.ann.common.Distance
import com.twitter.ann.common.EmbeddingType._
import com.twitter.ann.common.EntityEmbedding
import com.twitter.ann.common.IndexOutputFile
import com.twitter.ann.common.Metric
import com.twitter.ann.common.NeighborWithDistance
import com.twitter.ann.common.Queryable
import com.twitter.ann.common.RuntimeParams
import com.twitter.ann.common.Serialization
import com.twitter.ann.serialization.PersistedEmbeddingInjection
import com.twitter.ann.serialization.ThriftIteratorIO
import com.twitter.ann.serialization.thriftscala.PersistedEmbedding
import com.twitter.search.common.file.AbstractFile
import com.twitter.util.Future
import com.twitter.util.FuturePool
import java.util.concurrent.ConcurrentLinkedQueue
import org.apache.beam.sdk.io.fs.ResourceId
import scala.collection.JavaConverters._
import scala.collection.mutable
object BruteForceRuntimeParams extends RuntimeParams
object BruteForceIndex {
val DataFileName = "BruteForceFileData"
def apply[T, D <: Distance[D]](
metric: Metric[D],
futurePool: FuturePool,
initialEmbeddings: Iterator[EntityEmbedding[T]] = Iterator()
): BruteForceIndex[T, D] = {
val linkedQueue = new ConcurrentLinkedQueue[EntityEmbedding[T]]
initialEmbeddings.foreach(embedding => linkedQueue.add(embedding))
new BruteForceIndex(metric, futurePool, linkedQueue)
}
}
class BruteForceIndex[T, D <: Distance[D]] private (
metric: Metric[D],
futurePool: FuturePool,
// visible for serialization
private[brute_force] val linkedQueue: ConcurrentLinkedQueue[EntityEmbedding[T]])
extends Appendable[T, BruteForceRuntimeParams.type, D]
with Queryable[T, BruteForceRuntimeParams.type, D] {
override def append(embedding: EntityEmbedding[T]): Future[Unit] = {
futurePool {
linkedQueue.add(embedding)
}
}
override def toQueryable: Queryable[T, BruteForceRuntimeParams.type, D] = this
override def query(
embedding: EmbeddingVector,
numOfNeighbours: Int,
runtimeParams: BruteForceRuntimeParams.type
): Future[List[T]] = {
queryWithDistance(embedding, numOfNeighbours, runtimeParams).map { neighborsWithDistance =>
neighborsWithDistance.map(_.neighbor)
}
}
override def queryWithDistance(
embedding: EmbeddingVector,
numOfNeighbours: Int,
runtimeParams: BruteForceRuntimeParams.type
): Future[List[NeighborWithDistance[T, D]]] = {
futurePool {
      // Order the max-heap by distance: dequeue then removes the farthest neighbor, keeping the k nearest.
val ordering = Ordering.by[NeighborWithDistance[T, D], D](_.distance)
val priorityQueue =
new mutable.PriorityQueue[NeighborWithDistance[T, D]]()(ordering)
linkedQueue
.iterator()
.asScala
.foreach { entity =>
val neighborWithDistance =
NeighborWithDistance(entity.id, metric.distance(entity.embedding, embedding))
priorityQueue.+=(neighborWithDistance)
if (priorityQueue.size > numOfNeighbours) {
priorityQueue.dequeue()
}
}
val reverseList: List[NeighborWithDistance[T, D]] =
priorityQueue.dequeueAll
reverseList.reverse
}
}
}
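
// A standalone sketch, not in the original file, of the bounded top-k pattern
// queryWithDistance uses above: a max-heap of size k ordered by distance evicts
// the farthest element on overflow, and dequeueAll.reverse yields nearest-first.
object TopKSketch {
  import scala.collection.mutable

  def topK(distances: Iterator[Float], k: Int): List[Float] = {
    val heap = new mutable.PriorityQueue[Float]() // dequeue removes the largest value
    distances.foreach { d =>
      heap += d
      if (heap.size > k) heap.dequeue() // evict the current farthest
    }
    heap.dequeueAll.reverse.toList // ascending distance: nearest first
  }
}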
object SerializableBruteForceIndex {
def apply[T, D <: Distance[D]](
metric: Metric[D],
futurePool: FuturePool,
embeddingInjection: PersistedEmbeddingInjection[T],
thriftIteratorIO: ThriftIteratorIO[PersistedEmbedding]
): SerializableBruteForceIndex[T, D] = {
val bruteForceIndex = BruteForceIndex[T, D](metric, futurePool)
new SerializableBruteForceIndex(bruteForceIndex, embeddingInjection, thriftIteratorIO)
}
}
/**
 * This is a class that wraps a BruteForceIndex and provides a method for serialization.
*
* @param bruteForceIndex all queries and updates are sent to this index.
* @param embeddingInjection injection that can convert embeddings to thrift embeddings.
* @param thriftIteratorIO class that provides a way to write PersistedEmbeddings to disk
*/
class SerializableBruteForceIndex[T, D <: Distance[D]](
bruteForceIndex: BruteForceIndex[T, D],
embeddingInjection: PersistedEmbeddingInjection[T],
thriftIteratorIO: ThriftIteratorIO[PersistedEmbedding])
extends Appendable[T, BruteForceRuntimeParams.type, D]
with Queryable[T, BruteForceRuntimeParams.type, D]
with Serialization {
import BruteForceIndex._
override def append(entity: EntityEmbedding[T]): Future[Unit] =
bruteForceIndex.append(entity)
override def toQueryable: Queryable[T, BruteForceRuntimeParams.type, D] = this
override def query(
embedding: EmbeddingVector,
numOfNeighbours: Int,
runtimeParams: BruteForceRuntimeParams.type
): Future[List[T]] =
bruteForceIndex.query(embedding, numOfNeighbours, runtimeParams)
override def queryWithDistance(
embedding: EmbeddingVector,
numOfNeighbours: Int,
runtimeParams: BruteForceRuntimeParams.type
): Future[List[NeighborWithDistance[T, D]]] =
bruteForceIndex.queryWithDistance(embedding, numOfNeighbours, runtimeParams)
override def toDirectory(serializationDirectory: ResourceId): Unit = {
toDirectory(new IndexOutputFile(serializationDirectory))
}
override def toDirectory(serializationDirectory: AbstractFile): Unit = {
toDirectory(new IndexOutputFile(serializationDirectory))
}
private def toDirectory(serializationDirectory: IndexOutputFile): Unit = {
val outputStream = serializationDirectory.createFile(DataFileName).getOutputStream()
val thriftEmbeddings =
bruteForceIndex.linkedQueue.iterator().asScala.map { embedding =>
embeddingInjection(embedding)
}
try {
thriftIteratorIO.toOutputStream(thriftEmbeddings, outputStream)
} finally {
outputStream.close()
}
}
}

View File

@@ -1,28 +0,0 @@
package com.twitter.ann.common
import com.twitter.bijection.{Bijection, Injection}
// Provides commonly used injections that can be used directly with the ANN APIs.
// Injections prefixed with `J` can be used directly from Java.
object AnnInjections {
val LongInjection: Injection[Long, Array[Byte]] = Injection.long2BigEndian
def StringInjection: Injection[String, Array[Byte]] = Injection.utf8
def IntInjection: Injection[Int, Array[Byte]] = Injection.int2BigEndian
val JLongInjection: Injection[java.lang.Long, Array[Byte]] =
Bijection.long2Boxed
.asInstanceOf[Bijection[Long, java.lang.Long]]
.inverse
.andThen(LongInjection)
val JStringInjection: Injection[java.lang.String, Array[Byte]] =
StringInjection
val JIntInjection: Injection[java.lang.Integer, Array[Byte]] =
Bijection.int2Boxed
.asInstanceOf[Bijection[Int, java.lang.Integer]]
.inverse
.andThen(IntInjection)
}
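
// A brief round-trip sketch, not part of the original file: injections encode
// ids as byte arrays, which is how the file-backed id stores persist them.
object AnnInjectionsRoundTrip {
  def example(): Unit = {
    val bytes: Array[Byte] = AnnInjections.LongInjection(42L)
    // invert returns a scala.util.Try; it fails on malformed bytes.
    val restored: Long = AnnInjections.LongInjection.invert(bytes).get
    assert(restored == 42L)
  }
}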

Binary file not shown.

View File

@@ -1,150 +0,0 @@
package com.twitter.ann.common
import com.twitter.ann.common.EmbeddingType.EmbeddingVector
import com.twitter.ml.api.embedding.Embedding
import com.twitter.ml.api.embedding.EmbeddingMath
import com.twitter.ml.api.embedding.EmbeddingSerDe
import com.twitter.util.Future
object EmbeddingType {
type EmbeddingVector = Embedding[Float]
val embeddingSerDe = EmbeddingSerDe.apply[Float]
private[common] val math = EmbeddingMath.Float
}
/**
* Typed entity with an embedding associated with it.
* @param id : Unique Id for an entity.
* @param embedding : Embedding/Vector of an entity.
* @tparam T: Type of id.
*/
case class EntityEmbedding[T](id: T, embedding: EmbeddingVector)
// Query interface for ANN
trait Queryable[T, P <: RuntimeParams, D <: Distance[D]] {
/**
* ANN query for ids.
* @param embedding: Embedding/Vector to be queried with.
* @param numOfNeighbors: Number of neighbours to be queried for.
* @param runtimeParams: Runtime params associated with index to control accuracy/latency etc.
* @return List of approximate nearest neighbour ids.
*/
def query(
embedding: EmbeddingVector,
numOfNeighbors: Int,
runtimeParams: P
): Future[List[T]]
/**
* ANN query for ids with distance.
* @param embedding: Embedding/Vector to be queried with.
* @param numOfNeighbors: Number of neighbours to be queried for.
* @param runtimeParams: Runtime params associated with index to control accuracy/latency etc.
* @return List of approximate nearest neighbour ids with distance from the query embedding.
*/
def queryWithDistance(
embedding: EmbeddingVector,
numOfNeighbors: Int,
runtimeParams: P
): Future[List[NeighborWithDistance[T, D]]]
}
// Query interface for ANN over indexes that are grouped
trait QueryableGrouped[T, P <: RuntimeParams, D <: Distance[D]] extends Queryable[T, P, D] {
/**
* ANN query for ids.
* @param embedding: Embedding/Vector to be queried with.
* @param numOfNeighbors: Number of neighbours to be queried for.
* @param runtimeParams: Runtime params associated with index to control accuracy/latency etc.
* @param key: Optional key to lookup specific ANN index and perform query there
* @return List of approximate nearest neighbour ids.
*/
def query(
embedding: EmbeddingVector,
numOfNeighbors: Int,
runtimeParams: P,
key: Option[String]
): Future[List[T]]
/**
* ANN query for ids with distance.
* @param embedding: Embedding/Vector to be queried with.
* @param numOfNeighbors: Number of neighbours to be queried for.
* @param runtimeParams: Runtime params associated with index to control accuracy/latency etc.
* @param key: Optional key to lookup specific ANN index and perform query there
* @return List of approximate nearest neighbour ids with distance from the query embedding.
*/
def queryWithDistance(
embedding: EmbeddingVector,
numOfNeighbors: Int,
runtimeParams: P,
key: Option[String]
): Future[List[NeighborWithDistance[T, D]]]
}
/**
 * Runtime params associated with the index to control accuracy/latency etc. while querying.
*/
trait RuntimeParams {}
/**
* ANN query result with distance.
 * @param neighbor : Id of the neighbour
 * @param distance: Distance of the neighbour from the query, e.g. CosineDistance, L2Distance, InnerProductDistance
*/
case class NeighborWithDistance[T, D <: Distance[D]](neighbor: T, distance: D)
/**
* ANN query result with seed entity for which this neighbor was provided.
 * @param seed: Seed Id for which the ANN query was called
 * @param neighbor : Id of the neighbour
*/
case class NeighborWithSeed[T1, T2](seed: T1, neighbor: T2)
/**
* ANN query result with distance with seed entity for which this neighbor was provided.
 * @param seed: Seed Id for which the ANN query was called
 * @param neighbor : Id of the neighbour
 * @param distance: Distance of the neighbour from the query, e.g. CosineDistance, L2Distance, InnerProductDistance
*/
case class NeighborWithDistanceWithSeed[T1, T2, D <: Distance[D]](
seed: T1,
neighbor: T2,
distance: D)
trait RawAppendable[P <: RuntimeParams, D <: Distance[D]] {
/**
* Append an embedding in an index.
* @param embedding: Embedding/Vector
 * @return Future of the autogenerated long id associated with the embedding.
*/
def append(embedding: EmbeddingVector): Future[Long]
/**
* Convert an Appendable to Queryable interface to query an index.
*/
def toQueryable: Queryable[Long, P, D]
}
// Index building interface for ANN.
trait Appendable[T, P <: RuntimeParams, D <: Distance[D]] {
/**
* Append an entity with embedding in an index.
* @param entity: Entity with its embedding
*/
def append(entity: EntityEmbedding[T]): Future[Unit]
/**
* Convert an Appendable to Queryable interface to query an index.
*/
def toQueryable: Queryable[T, P, D]
}
// Updatable index interface for ANN.
trait Updatable[T] {
def update(entity: EntityEmbedding[T]): Future[Unit]
}
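
// A hedged sketch, not in the original file, of the build-then-query flow these
// interfaces imply: append entities, convert to a Queryable, then query. The
// neighbour count of 10 is an arbitrary illustration.
object AppendThenQuerySketch {
  import com.twitter.util.Future

  def buildAndQuery[T, P <: RuntimeParams, D <: Distance[D]](
    index: Appendable[T, P, D],
    entities: Seq[EntityEmbedding[T]],
    query: EmbeddingType.EmbeddingVector,
    params: P
  ): Future[List[T]] = {
    Future
      .collect(entities.map(index.append)) // write phase
      .flatMap { _ =>
        index.toQueryable.query(query, 10, params) // read phase
      }
  }
}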

View File

@@ -1,21 +0,0 @@
scala_library(
sources = ["*.scala"],
compiler_option_sets = ["fatal_warnings"],
platform = "java8",
tags = ["bazel-compatible"],
dependencies = [
"3rdparty/jvm/com/google/guava",
"3rdparty/jvm/com/twitter/bijection:core",
"3rdparty/jvm/com/twitter/storehaus:core",
"3rdparty/jvm/org/apache/beam:beam-sdks-java-io-google-cloud-platform",
"ann/src/main/thrift/com/twitter/ann/common:ann-common-scala",
"finatra/inject/inject-mdc/src/main/scala",
"mediaservices/commons/src/main/scala:futuretracker",
"src/java/com/twitter/search/common/file",
"src/scala/com/twitter/ml/api/embedding",
"stitch/stitch-core",
],
exports = [
"3rdparty/jvm/com/twitter/bijection:core",
],
)

Binary file not shown.

View File

@@ -1,13 +0,0 @@
package com.twitter.ann.common
import com.twitter.stitch.Stitch
trait EmbeddingProducer[T] {
/**
* Produce an embedding from type T. Implementations of this could do a lookup from an id to an
 * embedding. Or they could run a deep model on features that outputs an embedding.
* @return An embedding Stitch. See go/stitch for details on how to use the Stitch API.
*/
def produceEmbedding(input: T): Stitch[Option[EmbeddingType.EmbeddingVector]]
}
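
// A minimal example implementation, not from the original file: an
// EmbeddingProducer backed by an in-memory map, handy for tests.
class MapEmbeddingProducer[T](lookup: Map[T, EmbeddingType.EmbeddingVector])
    extends EmbeddingProducer[T] {
  override def produceEmbedding(input: T): Stitch[Option[EmbeddingType.EmbeddingVector]] =
    Stitch.value(lookup.get(input))
}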

View File

@@ -1,226 +0,0 @@
package com.twitter.ann.common
import com.google.common.io.ByteStreams
import com.twitter.ann.common.thriftscala.AnnIndexMetadata
import com.twitter.mediaservices.commons.codec.ArrayByteBufferCodec
import com.twitter.mediaservices.commons.codec.ThriftByteBufferCodec
import com.twitter.search.common.file.AbstractFile
import java.io.IOException
import java.io.InputStream
import java.io.OutputStream
import java.nio.channels.Channels
import org.apache.beam.sdk.io.FileSystems
import org.apache.beam.sdk.io.fs.MoveOptions
import org.apache.beam.sdk.io.fs.ResolveOptions
import org.apache.beam.sdk.io.fs.ResolveOptions.StandardResolveOptions
import org.apache.beam.sdk.io.fs.ResourceId
import org.apache.beam.sdk.util.MimeTypes
import org.apache.hadoop.io.IOUtils
import scala.collection.JavaConverters._
/**
* This class creates a wrapper around GCS filesystem and HDFS filesystem for the index
* generation job. It implements the basic methods required by the index generation job and hides
* the logic around handling HDFS vs GCS.
*/
class IndexOutputFile(val abstractFile: AbstractFile, val resourceId: ResourceId) {
// Success file name
private val SUCCESS_FILE = "_SUCCESS"
private val INDEX_METADATA_FILE = "ANN_INDEX_METADATA"
private val MetadataCodec = new ThriftByteBufferCodec[AnnIndexMetadata](AnnIndexMetadata)
/**
* Constructor for ResourceId. This is used for GCS filesystem
* @param resourceId
*/
def this(resourceId: ResourceId) = {
this(null, resourceId)
}
/**
* Constructor for AbstractFile. This is used for HDFS and local filesystem
* @param abstractFile
*/
def this(abstractFile: AbstractFile) = {
this(abstractFile, null)
}
/**
* Returns true if this instance is around an AbstractFile.
* @return
*/
def isAbstractFile(): Boolean = {
abstractFile != null
}
/**
* Creates a _SUCCESS file in the current directory.
*/
def createSuccessFile(): Unit = {
if (isAbstractFile()) {
abstractFile.createSuccessFile()
} else {
val successFile =
resourceId.resolve(SUCCESS_FILE, ResolveOptions.StandardResolveOptions.RESOLVE_FILE)
val successWriterChannel = FileSystems.create(successFile, MimeTypes.BINARY)
successWriterChannel.close()
}
}
/**
* Returns whether the current instance represents a directory
* @return True if the current instance is a directory
*/
def isDirectory(): Boolean = {
if (isAbstractFile()) {
abstractFile.isDirectory
} else {
resourceId.isDirectory
}
}
/**
* Return the current path of the file represented by the current instance
* @return The path string of the file/directory
*/
def getPath(): String = {
if (isAbstractFile()) {
abstractFile.getPath.toString
} else {
if (resourceId.isDirectory) {
resourceId.getCurrentDirectory.toString
} else {
resourceId.getCurrentDirectory.toString + resourceId.getFilename
}
}
}
/**
 * Creates a new file named `fileName` in the current directory.
* @param fileName
* @return A new file inside the current directory
*/
def createFile(fileName: String): IndexOutputFile = {
if (isAbstractFile()) {
// AbstractFile treats files and directories the same way. Hence, not checking for directory
// here.
new IndexOutputFile(abstractFile.getChild(fileName))
} else {
if (!resourceId.isDirectory) {
// If this is not a directory, throw exception.
throw new IllegalArgumentException(getPath() + " is not a directory.")
}
new IndexOutputFile(
resourceId.resolve(fileName, ResolveOptions.StandardResolveOptions.RESOLVE_FILE))
}
}
/**
 * Creates a new directory named `directoryName` in the current directory.
* @param directoryName
* @return A new directory inside the current directory
*/
def createDirectory(directoryName: String): IndexOutputFile = {
if (isAbstractFile()) {
// AbstractFile treats files and directories the same way. Hence, not checking for directory
// here.
val dir = abstractFile.getChild(directoryName)
dir.mkdirs()
new IndexOutputFile(dir)
} else {
if (!resourceId.isDirectory) {
// If this is not a directory, throw exception.
throw new IllegalArgumentException(getPath() + " is not a directory.")
}
val newResourceId =
resourceId.resolve(directoryName, ResolveOptions.StandardResolveOptions.RESOLVE_DIRECTORY)
// Create a tmp file and delete in order to trigger directory creation
val tmpFile =
newResourceId.resolve("tmp", ResolveOptions.StandardResolveOptions.RESOLVE_FILE)
val tmpWriterChannel = FileSystems.create(tmpFile, MimeTypes.BINARY)
tmpWriterChannel.close()
FileSystems.delete(List(tmpFile).asJava, MoveOptions.StandardMoveOptions.IGNORE_MISSING_FILES)
new IndexOutputFile(newResourceId)
}
}
def getChild(fileName: String, isDirectory: Boolean = false): IndexOutputFile = {
if (isAbstractFile()) {
new IndexOutputFile(abstractFile.getChild(fileName))
} else {
val resolveOption = if (isDirectory) {
StandardResolveOptions.RESOLVE_DIRECTORY
} else {
StandardResolveOptions.RESOLVE_FILE
}
new IndexOutputFile(resourceId.resolve(fileName, resolveOption))
}
}
/**
* Returns an OutputStream for the underlying file.
* Note: Close the OutputStream after writing
* @return
*/
def getOutputStream(): OutputStream = {
if (isAbstractFile()) {
abstractFile.getByteSink.openStream()
} else {
if (resourceId.isDirectory) {
// If this is a directory, throw exception.
throw new IllegalArgumentException(getPath() + " is a directory.")
}
val writerChannel = FileSystems.create(resourceId, MimeTypes.BINARY)
Channels.newOutputStream(writerChannel)
}
}
/**
* Returns an InputStream for the underlying file.
* Note: Close the InputStream after reading
* @return
*/
def getInputStream(): InputStream = {
if (isAbstractFile()) {
abstractFile.getByteSource.openStream()
} else {
if (resourceId.isDirectory) {
// If this is a directory, throw exception.
throw new IllegalArgumentException(getPath() + " is a directory.")
}
val readChannel = FileSystems.open(resourceId)
Channels.newInputStream(readChannel)
}
}
/**
* Copies content from the srcIn into the current file.
* @param srcIn
*/
def copyFrom(srcIn: InputStream): Unit = {
val out = getOutputStream()
try {
IOUtils.copyBytes(srcIn, out, 4096)
out.close()
} catch {
case ex: IOException =>
        IOUtils.closeStream(out)
        throw ex
}
}
def writeIndexMetadata(annIndexMetadata: AnnIndexMetadata): Unit = {
val out = createFile(INDEX_METADATA_FILE).getOutputStream()
val bytes = ArrayByteBufferCodec.decode(MetadataCodec.encode(annIndexMetadata))
out.write(bytes)
out.close()
}
def loadIndexMetadata(): AnnIndexMetadata = {
val in = ByteStreams.toByteArray(getInputStream())
MetadataCodec.decode(ArrayByteBufferCodec.encode(in))
}
}
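
// A short usage sketch, not in the original file: the same call sequence works
// whether the instance wraps an AbstractFile (HDFS/local) or a ResourceId (GCS).
// The directory and file names are illustrative.
object IndexOutputFileSketch {
  def writeIndex(root: IndexOutputFile, payload: Array[Byte]): Unit = {
    val dir = root.createDirectory("index_v1")
    val out = dir.createFile("data").getOutputStream()
    try out.write(payload)
    finally out.close()
    dir.createSuccessFile() // mark the directory as complete
  }
}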

View File

@@ -1,118 +0,0 @@
package com.twitter.ann.common
import com.twitter.ann.common.EmbeddingType.EmbeddingVector
import com.twitter.storehaus.{ReadableStore, Store}
import com.twitter.util.Future
// Utility to transform raw index to typed index using Store
object IndexTransformer {
/**
* Transform a long type queryable index to Typed queryable index
* @param index: Raw Queryable index
* @param store: Readable store to provide mappings between Long and T
* @tparam T: Type to transform to
* @tparam P: Runtime params
* @return Queryable index typed on T
*/
def transformQueryable[T, P <: RuntimeParams, D <: Distance[D]](
index: Queryable[Long, P, D],
store: ReadableStore[Long, T]
): Queryable[T, P, D] = {
new Queryable[T, P, D] {
override def query(
embedding: EmbeddingVector,
numOfNeighbors: Int,
runtimeParams: P
): Future[List[T]] = {
val neighbors = index.query(embedding, numOfNeighbors, runtimeParams)
neighbors
.flatMap(nn => {
val ids = nn.map(id => store.get(id).map(_.get))
Future
.collect(ids)
.map(_.toList)
})
}
override def queryWithDistance(
embedding: EmbeddingVector,
numOfNeighbors: Int,
runtimeParams: P
): Future[List[NeighborWithDistance[T, D]]] = {
val neighbors = index.queryWithDistance(embedding, numOfNeighbors, runtimeParams)
neighbors
.flatMap(nn => {
val ids = nn.map(obj =>
store.get(obj.neighbor).map(id => NeighborWithDistance(id.get, obj.distance)))
Future
.collect(ids)
.map(_.toList)
})
}
}
}
/**
* Transform a long type appendable index to Typed appendable index
* @param index: Raw Appendable index
* @param store: Writable store to store mappings between Long and T
* @tparam T: Type to transform to
* @return Appendable index typed on T
*/
def transformAppendable[T, P <: RuntimeParams, D <: Distance[D]](
index: RawAppendable[P, D],
store: Store[Long, T]
): Appendable[T, P, D] = {
new Appendable[T, P, D]() {
override def append(entity: EntityEmbedding[T]): Future[Unit] = {
index
.append(entity.embedding)
.flatMap(id => store.put((id, Some(entity.id))))
}
override def toQueryable: Queryable[T, P, D] = {
transformQueryable(index.toQueryable, store)
}
}
}
/**
* Transform a long type appendable and queryable index to Typed appendable and queryable index
* @param index: Raw Appendable and queryable index
* @param store: Store to provide/store mappings between Long and T
* @tparam T: Type to transform to
* @tparam Index: Index
* @return Appendable and queryable index typed on T
*/
def transform1[
Index <: RawAppendable[P, D] with Queryable[Long, P, D],
T,
P <: RuntimeParams,
D <: Distance[D]
](
index: Index,
store: Store[Long, T]
): Queryable[T, P, D] with Appendable[T, P, D] = {
val queryable = transformQueryable(index, store)
val appendable = transformAppendable(index, store)
new Queryable[T, P, D] with Appendable[T, P, D] {
override def query(
embedding: EmbeddingVector,
numOfNeighbors: Int,
runtimeParams: P
) = queryable.query(embedding, numOfNeighbors, runtimeParams)
override def queryWithDistance(
embedding: EmbeddingVector,
numOfNeighbors: Int,
runtimeParams: P
) = queryable.queryWithDistance(embedding, numOfNeighbors, runtimeParams)
override def append(entity: EntityEmbedding[T]) = appendable.append(entity)
override def toQueryable: Queryable[T, P, D] = appendable.toQueryable
}
}
}
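
// A hedged sketch, not in the original file: turning a raw Long-keyed index into
// a String-keyed one with a storehaus ReadableStore built from an in-memory map.
object IndexTransformerSketch {
  import com.twitter.storehaus.ReadableStore

  def typedQueryable[P <: RuntimeParams, D <: Distance[D]](
    rawIndex: Queryable[Long, P, D],
    idMapping: Map[Long, String]
  ): Queryable[String, P, D] =
    IndexTransformer.transformQueryable(rawIndex, ReadableStore.fromMap(idMapping))
}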

View File

@@ -1,37 +0,0 @@
package com.twitter.ann.common
import com.twitter.util.Return
import com.twitter.util.Throw
import com.twitter.util.Try
import com.twitter.util.logging.Logging
// Memoization with a twist:
// a new epoch reuses K:V pairs from the previous epoch and recycles everything else.
class MemoizedInEpochs[K, V](f: K => Try[V]) extends Logging {
private var memoizedCalls: Map[K, V] = Map.empty
def epoch(keys: Seq[K]): Seq[V] = {
val newSet = keys.toSet
val keysToBeComputed = newSet.diff(memoizedCalls.keySet)
val computedKeysAndValues = keysToBeComputed.map { key =>
info(s"Memoize ${key}")
(key, f(key))
}
val keysAndValuesAfterFilteringFailures = computedKeysAndValues
.flatMap {
case (key, Return(value)) => Some((key, value))
case (key, Throw(e)) =>
warn(s"Calling f for ${key} has failed", e)
None
}
val keysReusedFromLastEpoch = memoizedCalls.filterKeys(newSet.contains)
memoizedCalls = keysReusedFromLastEpoch ++ keysAndValuesAfterFilteringFailures
debug(s"Final memoization is ${memoizedCalls.keys.mkString(", ")}")
keys.flatMap(memoizedCalls.get)
}
def currentEpochKeys: Set[K] = memoizedCalls.keySet
}
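
// A small usage sketch, not part of the original file: keys seen in the previous
// epoch are served from the memo, dropped keys are forgotten, failures are skipped.
object MemoizedInEpochsSketch {
  import com.twitter.util.Try

  def example(): Unit = {
    val loader = new MemoizedInEpochs[String, Int](key => Try(key.length))
    loader.epoch(Seq("alpha", "beta")) // computes both
    loader.epoch(Seq("beta", "gamma")) // reuses "beta", computes "gamma", drops "alpha"
    assert(loader.currentEpochKeys == Set("beta", "gamma"))
  }
}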

Binary file not shown.

View File

@@ -1,290 +0,0 @@
package com.twitter.ann.common
import com.google.common.collect.ImmutableBiMap
import com.twitter.ann.common.EmbeddingType._
import com.twitter.ann.common.thriftscala.DistanceMetric
import com.twitter.ann.common.thriftscala.{CosineDistance => ServiceCosineDistance}
import com.twitter.ann.common.thriftscala.{Distance => ServiceDistance}
import com.twitter.ann.common.thriftscala.{InnerProductDistance => ServiceInnerProductDistance}
import com.twitter.ann.common.thriftscala.{EditDistance => ServiceEditDistance}
import com.twitter.ann.common.thriftscala.{L2Distance => ServiceL2Distance}
import com.twitter.bijection.Injection
import scala.util.Failure
import scala.util.Success
import scala.util.Try
// Ann distance metrics
trait Distance[D] extends Any with Ordered[D] {
def distance: Float
}
case class L2Distance(distance: Float) extends AnyVal with Distance[L2Distance] {
override def compare(that: L2Distance): Int =
Ordering.Float.compare(this.distance, that.distance)
}
case class CosineDistance(distance: Float) extends AnyVal with Distance[CosineDistance] {
override def compare(that: CosineDistance): Int =
Ordering.Float.compare(this.distance, that.distance)
}
case class InnerProductDistance(distance: Float)
extends AnyVal
with Distance[InnerProductDistance] {
override def compare(that: InnerProductDistance): Int =
Ordering.Float.compare(this.distance, that.distance)
}
case class EditDistance(distance: Float) extends AnyVal with Distance[EditDistance] {
override def compare(that: EditDistance): Int =
Ordering.Float.compare(this.distance, that.distance)
}
object Metric {
private[this] val thriftMetricMapping = ImmutableBiMap.of(
L2,
DistanceMetric.L2,
Cosine,
DistanceMetric.Cosine,
InnerProduct,
DistanceMetric.InnerProduct,
Edit,
DistanceMetric.EditDistance
)
def fromThrift(metric: DistanceMetric): Metric[_ <: Distance[_]] = {
thriftMetricMapping.inverse().get(metric)
}
def toThrift(metric: Metric[_ <: Distance[_]]): DistanceMetric = {
thriftMetricMapping.get(metric)
}
def fromString(metricName: String): Metric[_ <: Distance[_]]
with Injection[_, ServiceDistance] = {
metricName match {
case "Cosine" => Cosine
case "L2" => L2
case "InnerProduct" => InnerProduct
case "EditDistance" => Edit
case _ =>
throw new IllegalArgumentException(s"No Metric with the name $metricName")
}
}
}
sealed trait Metric[D <: Distance[D]] {
def distance(
embedding1: EmbeddingVector,
embedding2: EmbeddingVector
): D
def absoluteDistance(
embedding1: EmbeddingVector,
embedding2: EmbeddingVector
): Float
def fromAbsoluteDistance(distance: Float): D
}
case object L2 extends Metric[L2Distance] with Injection[L2Distance, ServiceDistance] {
override def distance(
embedding1: EmbeddingVector,
embedding2: EmbeddingVector
): L2Distance = {
fromAbsoluteDistance(MetricUtil.l2distance(embedding1, embedding2).toFloat)
}
override def fromAbsoluteDistance(distance: Float): L2Distance = {
L2Distance(distance)
}
override def absoluteDistance(
embedding1: EmbeddingVector,
embedding2: EmbeddingVector
): Float = distance(embedding1, embedding2).distance
override def apply(scalaDistance: L2Distance): ServiceDistance = {
ServiceDistance.L2Distance(ServiceL2Distance(scalaDistance.distance))
}
override def invert(serviceDistance: ServiceDistance): Try[L2Distance] = {
serviceDistance match {
case ServiceDistance.L2Distance(l2Distance) =>
Success(L2Distance(l2Distance.distance.toFloat))
case distance =>
Failure(new IllegalArgumentException(s"Expected an l2 distance but got $distance"))
}
}
}
case object Cosine extends Metric[CosineDistance] with Injection[CosineDistance, ServiceDistance] {
override def distance(
embedding1: EmbeddingVector,
embedding2: EmbeddingVector
): CosineDistance = {
fromAbsoluteDistance(1 - MetricUtil.cosineSimilarity(embedding1, embedding2))
}
override def fromAbsoluteDistance(distance: Float): CosineDistance = {
CosineDistance(distance)
}
override def absoluteDistance(
embedding1: EmbeddingVector,
embedding2: EmbeddingVector
): Float = distance(embedding1, embedding2).distance
override def apply(scalaDistance: CosineDistance): ServiceDistance = {
ServiceDistance.CosineDistance(ServiceCosineDistance(scalaDistance.distance))
}
override def invert(serviceDistance: ServiceDistance): Try[CosineDistance] = {
serviceDistance match {
case ServiceDistance.CosineDistance(cosineDistance) =>
Success(CosineDistance(cosineDistance.distance.toFloat))
case distance =>
Failure(new IllegalArgumentException(s"Expected a cosine distance but got $distance"))
}
}
}
case object InnerProduct
extends Metric[InnerProductDistance]
with Injection[InnerProductDistance, ServiceDistance] {
override def distance(
embedding1: EmbeddingVector,
embedding2: EmbeddingVector
): InnerProductDistance = {
fromAbsoluteDistance(1 - MetricUtil.dot(embedding1, embedding2))
}
override def fromAbsoluteDistance(distance: Float): InnerProductDistance = {
InnerProductDistance(distance)
}
override def absoluteDistance(
embedding1: EmbeddingVector,
embedding2: EmbeddingVector
): Float = distance(embedding1, embedding2).distance
override def apply(scalaDistance: InnerProductDistance): ServiceDistance = {
ServiceDistance.InnerProductDistance(ServiceInnerProductDistance(scalaDistance.distance))
}
override def invert(
serviceDistance: ServiceDistance
): Try[InnerProductDistance] = {
serviceDistance match {
      case ServiceDistance.InnerProductDistance(innerProductDistance) =>
        Success(InnerProductDistance(innerProductDistance.distance.toFloat))
case distance =>
Failure(
          new IllegalArgumentException(s"Expected an inner product distance but got $distance")
)
}
}
}
case object Edit extends Metric[EditDistance] with Injection[EditDistance, ServiceDistance] {
private def intDistance(
embedding1: EmbeddingVector,
embedding2: EmbeddingVector,
pos1: Int,
pos2: Int,
precomputedDistances: scala.collection.mutable.Map[(Int, Int), Int]
): Int = {
    // If one string is exhausted, the distance is the remaining length of the other.
    if (pos1 == 0) return pos2
    if (pos2 == 0) return pos1
    // Reuse the result if this (pos1, pos2) subproblem has already been computed.
precomputedDistances.getOrElse(
(pos1, pos2), {
// We might want to change this so that capitals are considered the same.
// Also maybe some characters that look similar should also be the same.
val computed = if (embedding1(pos1 - 1) == embedding2(pos2 - 1)) {
intDistance(embedding1, embedding2, pos1 - 1, pos2 - 1, precomputedDistances)
      } else { // If characters are not equal, we need to
// find the minimum cost out of all 3 operations.
val insert = intDistance(embedding1, embedding2, pos1, pos2 - 1, precomputedDistances)
val del = intDistance(embedding1, embedding2, pos1 - 1, pos2, precomputedDistances)
val replace =
intDistance(embedding1, embedding2, pos1 - 1, pos2 - 1, precomputedDistances)
1 + Math.min(insert, Math.min(del, replace))
}
precomputedDistances.put((pos1, pos2), computed)
computed
}
)
}
override def distance(
embedding1: EmbeddingVector,
embedding2: EmbeddingVector
): EditDistance = {
val editDistance = intDistance(
embedding1,
embedding2,
embedding1.length,
embedding2.length,
scala.collection.mutable.Map[(Int, Int), Int]()
)
EditDistance(editDistance)
}
override def fromAbsoluteDistance(distance: Float): EditDistance = {
EditDistance(distance.toInt)
}
override def absoluteDistance(
embedding1: EmbeddingVector,
embedding2: EmbeddingVector
): Float = distance(embedding1, embedding2).distance
override def apply(scalaDistance: EditDistance): ServiceDistance = {
ServiceDistance.EditDistance(ServiceEditDistance(scalaDistance.distance.toInt))
}
override def invert(
serviceDistance: ServiceDistance
): Try[EditDistance] = {
serviceDistance match {
      case ServiceDistance.EditDistance(editDistance) =>
        Success(EditDistance(editDistance.distance.toFloat))
case distance =>
Failure(
          new IllegalArgumentException(s"Expected an edit distance but got $distance")
)
}
}
}
object MetricUtil {
private[ann] def dot(
embedding1: EmbeddingVector,
embedding2: EmbeddingVector
): Float = {
math.dotProduct(embedding1, embedding2)
}
private[ann] def l2distance(
embedding1: EmbeddingVector,
embedding2: EmbeddingVector
): Double = {
math.l2Distance(embedding1, embedding2)
}
private[ann] def cosineSimilarity(
embedding1: EmbeddingVector,
embedding2: EmbeddingVector
): Float = {
math.cosineSimilarity(embedding1, embedding2).toFloat
}
private[ann] def norm(
embedding: EmbeddingVector
): EmbeddingVector = {
math.normalize(embedding)
}
}
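
// A worked sketch, not from the original file: computing distances with the
// metric objects above. It assumes the Embedding companion in
// com.twitter.ml.api.embedding accepts a float array.
object MetricSketch {
  import com.twitter.ml.api.embedding.Embedding

  def example(): Unit = {
    val a = Embedding(Array(1.0f, 0.0f))
    val b = Embedding(Array(0.0f, 1.0f))
    // Orthogonal unit vectors: cosine similarity 0, so cosine distance 1.
    val cosine: CosineDistance = Cosine.distance(a, b)
    // Straight-line distance: sqrt(2), roughly 1.414f.
    val l2: L2Distance = L2.distance(a, b)
    println(s"cosine=${cosine.distance} l2=${l2.distance}")
  }
}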

View File

@@ -1,41 +0,0 @@
package com.twitter.ann.common
import com.twitter.stitch.Stitch
/**
* This is a trait that allows you to query for nearest neighbors given an arbitrary type T1. This is
 * in contrast to a regular com.twitter.ann.common.Queryable, which takes an embedding as the input
* argument.
*
* This interface uses the Stitch API for batching. See go/stitch for details on how to use it.
*
* @tparam T1 type of the query.
* @tparam T2 type of the result.
* @tparam P runtime parameters supported by the index.
* @tparam D distance function used in the index.
*/
trait QueryableById[T1, T2, P <: RuntimeParams, D <: Distance[D]] {
def queryById(
id: T1,
numOfNeighbors: Int,
runtimeParams: P
): Stitch[List[T2]]
def queryByIdWithDistance(
id: T1,
numOfNeighbors: Int,
runtimeParams: P
): Stitch[List[NeighborWithDistance[T2, D]]]
def batchQueryById(
ids: Seq[T1],
numOfNeighbors: Int,
runtimeParams: P
): Stitch[List[NeighborWithSeed[T1, T2]]]
def batchQueryWithDistanceById(
ids: Seq[T1],
numOfNeighbors: Int,
runtimeParams: P
): Stitch[List[NeighborWithDistanceWithSeed[T1, T2, D]]]
}

View File

@@ -1,91 +0,0 @@
package com.twitter.ann.common
import com.twitter.stitch.Stitch
/**
* Implementation of QueryableById that composes an EmbeddingProducer and a Queryable so that we
* can get nearest neighbors given an id of type T1
* @param embeddingProducer provides an embedding given an id.
* @param queryable provides a list of neighbors given an embedding.
* @tparam T1 type of the query.
* @tparam T2 type of the result.
* @tparam P runtime parameters supported by the index.
* @tparam D distance function used in the index.
*/
class QueryableByIdImplementation[T1, T2, P <: RuntimeParams, D <: Distance[D]](
embeddingProducer: EmbeddingProducer[T1],
queryable: Queryable[T2, P, D])
extends QueryableById[T1, T2, P, D] {
override def queryById(
id: T1,
numOfNeighbors: Int,
runtimeParams: P
): Stitch[List[T2]] = {
embeddingProducer.produceEmbedding(id).flatMap { embeddingOption =>
embeddingOption
.map { embedding =>
Stitch.callFuture(queryable.query(embedding, numOfNeighbors, runtimeParams))
}.getOrElse {
Stitch.value(List.empty)
}
}
}
override def queryByIdWithDistance(
id: T1,
numOfNeighbors: Int,
runtimeParams: P
): Stitch[List[NeighborWithDistance[T2, D]]] = {
embeddingProducer.produceEmbedding(id).flatMap { embeddingOption =>
embeddingOption
.map { embedding =>
Stitch.callFuture(queryable.queryWithDistance(embedding, numOfNeighbors, runtimeParams))
}.getOrElse {
Stitch.value(List.empty)
}
}
}
override def batchQueryById(
ids: Seq[T1],
numOfNeighbors: Int,
runtimeParams: P
): Stitch[List[NeighborWithSeed[T1, T2]]] = {
Stitch
.traverse(ids) { id =>
embeddingProducer.produceEmbedding(id).flatMap { embeddingOption =>
embeddingOption
.map { embedding =>
Stitch
.callFuture(queryable.query(embedding, numOfNeighbors, runtimeParams)).map(
_.map(neighbor => NeighborWithSeed(id, neighbor)))
}.getOrElse {
Stitch.value(List.empty)
}.handle { case _ => List.empty }
}
}.map { _.toList.flatten }
}
override def batchQueryWithDistanceById(
ids: Seq[T1],
numOfNeighbors: Int,
runtimeParams: P
): Stitch[List[NeighborWithDistanceWithSeed[T1, T2, D]]] = {
Stitch
.traverse(ids) { id =>
embeddingProducer.produceEmbedding(id).flatMap { embeddingOption =>
embeddingOption
.map { embedding =>
Stitch
.callFuture(queryable.queryWithDistance(embedding, numOfNeighbors, runtimeParams))
.map(_.map(neighbor =>
NeighborWithDistanceWithSeed(id, neighbor.neighbor, neighbor.distance)))
}.getOrElse {
Stitch.value(List.empty)
}.handle { case _ => List.empty }
}
}.map {
_.toList.flatten
}
}
}

View File

@@ -1,26 +0,0 @@
package com.twitter.ann.common
import com.twitter.ann.common.EmbeddingType.EmbeddingVector
import com.twitter.util.Future
object QueryableOperations {
implicit class Map[T, P <: RuntimeParams, D <: Distance[D]](
val q: Queryable[T, P, D]) {
def mapRuntimeParameters(f: P => P): Queryable[T, P, D] = {
new Queryable[T, P, D] {
def query(
embedding: EmbeddingVector,
numOfNeighbors: Int,
runtimeParams: P
): Future[List[T]] = q.query(embedding, numOfNeighbors, f(runtimeParams))
def queryWithDistance(
embedding: EmbeddingVector,
numOfNeighbors: Int,
runtimeParams: P
): Future[List[NeighborWithDistance[T, D]]] =
q.queryWithDistance(embedding, numOfNeighbors, f(runtimeParams))
}
}
}
}
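
// A usage sketch, not in the original file: mapRuntimeParameters rewrites the
// params on every call, e.g. to enforce a caller-supplied default without
// touching call sites.
object QueryableOperationsSketch {
  import QueryableOperations._

  def withAdjustedParams[T, P <: RuntimeParams, D <: Distance[D]](
    q: Queryable[T, P, D],
    adjust: P => P
  ): Queryable[T, P, D] =
    q.mapRuntimeParameters(adjust)
}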

View File

@@ -1,29 +0,0 @@
package com.twitter.ann.common
import com.google.common.annotations.VisibleForTesting
import com.twitter.util.{Future, FuturePool}
trait ReadWriteFuturePool {
def read[T](f: => T): Future[T]
def write[T](f: => T): Future[T]
}
object ReadWriteFuturePool {
def apply(readPool: FuturePool, writePool: FuturePool): ReadWriteFuturePool = {
new ReadWriteFuturePoolANN(readPool, writePool)
}
def apply(commonPool: FuturePool): ReadWriteFuturePool = {
new ReadWriteFuturePoolANN(commonPool, commonPool)
}
}
@VisibleForTesting
private[ann] class ReadWriteFuturePoolANN(readPool: FuturePool, writePool: FuturePool)
extends ReadWriteFuturePool {
def read[T](f: => T): Future[T] = {
readPool.apply(f)
}
def write[T](f: => T): Future[T] = {
writePool.apply(f)
}
}
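
// A short sketch, not in the original file: separate executors isolate slow
// writes from latency-sensitive reads. The pool sizes are illustrative.
object ReadWriteFuturePoolSketch {
  import com.twitter.util.FuturePool
  import java.util.concurrent.Executors

  def make(): ReadWriteFuturePool =
    ReadWriteFuturePool(
      FuturePool(Executors.newFixedThreadPool(8)), // read pool
      FuturePool(Executors.newFixedThreadPool(2)) // write pool
    )
}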

View File

@@ -1,28 +0,0 @@
package com.twitter.ann.common
import com.twitter.search.common.file.AbstractFile
import org.apache.beam.sdk.io.fs.ResourceId
/**
* Interface for writing an Appendable to a directory.
*/
trait Serialization {
def toDirectory(
serializationDirectory: AbstractFile
): Unit
def toDirectory(
serializationDirectory: ResourceId
): Unit
}
/**
* Interface for reading a Queryable from a directory
* @tparam T the id of the embeddings
* @tparam Q type of the Queryable that is deserialized.
*/
trait QueryableDeserialization[T, P <: RuntimeParams, D <: Distance[D], Q <: Queryable[T, P, D]] {
def fromDirectory(
serializationDirectory: AbstractFile
): Q
}

View File

@@ -1,64 +0,0 @@
package com.twitter.ann.common
import com.twitter.ann.common.EmbeddingType._
import com.twitter.ann.common.thriftscala.{
NearestNeighborQuery,
NearestNeighborResult,
Distance => ServiceDistance,
RuntimeParams => ServiceRuntimeParams
}
import com.twitter.bijection.Injection
import com.twitter.finagle.Service
import com.twitter.mediaservices.commons.codec.ArrayByteBufferCodec
import com.twitter.util.Future
class ServiceClientQueryable[T, P <: RuntimeParams, D <: Distance[D]](
service: Service[NearestNeighborQuery, NearestNeighborResult],
runtimeParamInjection: Injection[P, ServiceRuntimeParams],
distanceInjection: Injection[D, ServiceDistance],
idInjection: Injection[T, Array[Byte]])
extends Queryable[T, P, D] {
override def query(
embedding: EmbeddingVector,
numOfNeighbors: Int,
runtimeParams: P
): Future[List[T]] = {
service
.apply(
NearestNeighborQuery(
embeddingSerDe.toThrift(embedding),
withDistance = false,
runtimeParamInjection(runtimeParams),
numOfNeighbors
)
)
.map { result =>
result.nearestNeighbors.map { nearestNeighbor =>
idInjection.invert(ArrayByteBufferCodec.decode(nearestNeighbor.id)).get
}.toList
}
}
override def queryWithDistance(
embedding: EmbeddingVector,
numOfNeighbors: Int,
runtimeParams: P
): Future[List[NeighborWithDistance[T, D]]] =
service
.apply(
NearestNeighborQuery(
embeddingSerDe.toThrift(embedding),
withDistance = true,
runtimeParamInjection(runtimeParams),
numOfNeighbors
)
)
.map { result =>
result.nearestNeighbors.map { nearestNeighbor =>
NeighborWithDistance(
idInjection.invert(ArrayByteBufferCodec.decode(nearestNeighbor.id)).get,
distanceInjection.invert(nearestNeighbor.distance.get).get
)
}.toList
}
}

View File

@@ -1,87 +0,0 @@
package com.twitter.ann.common
import com.twitter.ann.common.EmbeddingType.EmbeddingVector
import com.twitter.util.Future
import scala.util.Random
trait ShardFunction[T] {
/**
   * Shard function to shard an embedding based on the total shard count and the embedding data.
   * @param shards
   * @param entity
   * @return Shard index, from 0 (inclusive) to shards (exclusive)
*/
def apply(shards: Int, entity: EntityEmbedding[T]): Int
}
/**
 * Randomly shards the embeddings based on the total number of shards.
*/
class RandomShardFunction[T] extends ShardFunction[T] {
def apply(shards: Int, entity: EntityEmbedding[T]): Int = {
Random.nextInt(shards)
}
}
/**
* Sharded appendable to shard the embedding into different appendable indices
* @param indices: Sequence of appendable indices
* @param shardFn: Shard function to shard data into different indices
* @param shards: Total shards
* @tparam T: Type of id.
*/
class ShardedAppendable[T, P <: RuntimeParams, D <: Distance[D]](
indices: Seq[Appendable[T, P, D]],
shardFn: ShardFunction[T],
shards: Int)
extends Appendable[T, P, D] {
override def append(entity: EntityEmbedding[T]): Future[Unit] = {
val shard = shardFn(shards, entity)
val index = indices(shard)
index.append(entity)
}
override def toQueryable: Queryable[T, P, D] = {
new ComposedQueryable[T, P, D](indices.map(_.toQueryable))
}
}
/**
 * Composition of a sequence of queryable indices; it queries all the indices
 * and merges the results in memory to return the K nearest neighbours.
* @param indices: Sequence of queryable indices
* @tparam T: Type of id
* @tparam P: Type of runtime param
* @tparam D: Type of distance metric
*/
class ComposedQueryable[T, P <: RuntimeParams, D <: Distance[D]](
indices: Seq[Queryable[T, P, D]])
extends Queryable[T, P, D] {
private[this] val ordering =
Ordering.by[NeighborWithDistance[T, D], D](_.distance)
override def query(
embedding: EmbeddingVector,
numOfNeighbors: Int,
runtimeParams: P
): Future[List[T]] = {
val neighbours = queryWithDistance(embedding, numOfNeighbors, runtimeParams)
neighbours.map(list => list.map(nn => nn.neighbor))
}
override def queryWithDistance(
embedding: EmbeddingVector,
numOfNeighbors: Int,
runtimeParams: P
): Future[List[NeighborWithDistance[T, D]]] = {
val futures = Future.collect(
indices.map(index => index.queryWithDistance(embedding, numOfNeighbors, runtimeParams))
)
futures.map { list =>
list.flatten
.sorted(ordering)
.take(numOfNeighbors)
.toList
}
}
}
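
// A hedged sketch, not part of the original file: random sharding across two
// sub-indices; toQueryable yields a ComposedQueryable that merges per-shard
// results into a global top-k.
object ShardingSketch {
  def sharded[T, P <: RuntimeParams, D <: Distance[D]](
    shard0: Appendable[T, P, D],
    shard1: Appendable[T, P, D]
  ): Appendable[T, P, D] =
    new ShardedAppendable[T, P, D](Seq(shard0, shard1), new RandomShardFunction[T], shards = 2)
}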

Some files were not shown because too many files have changed in this diff.