mirror of
https://github.com/twitter/the-algorithm.git
synced 2024-06-30 06:56:08 +02:00
33 lines
806 B
Scala
33 lines
806 B
Scala
|
package com.twitter.home_mixer.util
|
||
|
|
||
|
import com.twitter.ml.api.thriftscala.FloatTensor
|
||
|
import com.twitter.ml.api.util.BufferToIterators.RichFloatBuffer
|
||
|
import java.nio.ByteBuffer
|
||
|
import java.nio.ByteOrder
|
||
|
|
||
|
/**
|
||
|
* Contains functionality to transform data records and Tensors
|
||
|
*/
|
||
|
|
||
|
object TensorFlowUtil {
|
||
|
|
||
|
private def skipEmbeddingBBHeader(bb: ByteBuffer): ByteBuffer = {
|
||
|
val bb_copy = bb.duplicate()
|
||
|
bb_copy.getLong()
|
||
|
bb_copy
|
||
|
}
|
||
|
|
||
|
private def byteBufferToFloatIterator(
|
||
|
bb: ByteBuffer
|
||
|
): Iterator[Float] = {
|
||
|
bb.order(ByteOrder.LITTLE_ENDIAN).asFloatBuffer.iterator
|
||
|
}
|
||
|
|
||
|
def embeddingByteBufferToFloatTensor(
|
||
|
bb: ByteBuffer
|
||
|
): FloatTensor = {
|
||
|
val bb_content = skipEmbeddingBBHeader(bb)
|
||
|
FloatTensor(byteBufferToFloatIterator(bb_content).map(_.toDouble).toList)
|
||
|
}
|
||
|
}
|