mirror of
https://github.com/twitter/the-algorithm.git
synced 2025-01-07 09:58:17 +01:00
ef4c5eb65e
Please note we have force-pushed a new initial commit in order to remove some publicly-available Twitter user information. Note that this process may be required in the future.
93 lines
2.4 KiB
C++
93 lines
2.4 KiB
C++
#pragma once
|
|
#include <twml/Tensor.h>
|
|
#include <type_traits>
|
|
|
|
#ifdef __cplusplus
|
|
namespace twml {
|
|
|
|
// This class contains the raw pointers to tensors coming from thrift object.
|
|
class TWMLAPI RawTensor : public Tensor
|
|
{
|
|
private:
|
|
bool m_is_big_endian;
|
|
uint64_t m_raw_length;
|
|
public:
|
|
|
|
RawTensor() {}
|
|
|
|
RawTensor(void *data, const std::vector<uint64_t> &dims,
|
|
const std::vector<uint64_t> &strides, twml_type type, bool is_big_endian, uint64_t length)
|
|
: Tensor(data, dims, strides, type), m_is_big_endian(is_big_endian), m_raw_length(length) {}
|
|
|
|
bool is_big_endian() const {
|
|
return m_is_big_endian;
|
|
}
|
|
|
|
uint64_t getRawLength() const {
|
|
return m_raw_length;
|
|
}
|
|
|
|
// Extracts a slice from a tensor at idx0 along dimension 0
|
|
// Used in BatchPredictionResponse to write each slice in separate records
|
|
RawTensor getSlice(uint64_t idx0) const {
|
|
void *slice = nullptr;
|
|
uint64_t raw_length = 0;
|
|
|
|
if (getType() == TWML_TYPE_STRING) {
|
|
raw_length = getStride(0);
|
|
std::string *data = const_cast<std::string *>(static_cast<const std::string*>(getData<void>()));
|
|
slice = static_cast<void *>(data + raw_length * idx0);
|
|
} else {
|
|
raw_length = getStride(0) * getSizeOf(getType());
|
|
char *data = const_cast<char *>(static_cast<const char*>(getData<void>()));
|
|
slice = static_cast<void *>(data + raw_length * idx0);
|
|
}
|
|
|
|
std::vector<uint64_t> dims, strides;
|
|
for (int i = 1; i < getNumDims(); i++) {
|
|
dims.push_back(getDim(i));
|
|
strides.push_back(getStride(i));
|
|
}
|
|
|
|
return RawTensor(slice, dims, strides, getType(), m_is_big_endian, raw_length);
|
|
}
|
|
};
|
|
|
|
// Wrapper class around RawTensor to hold sparse tensors.
|
|
class TWMLAPI RawSparseTensor
|
|
{
|
|
private:
|
|
RawTensor m_indices;
|
|
RawTensor m_values;
|
|
std::vector<uint64_t> m_dense_shape;
|
|
|
|
public:
|
|
|
|
RawSparseTensor() {
|
|
}
|
|
|
|
RawSparseTensor(const RawTensor &indices_, const RawTensor &values_,
|
|
const std::vector<uint64_t> &dense_shape_) :
|
|
m_indices(indices_), m_values(values_), m_dense_shape(dense_shape_)
|
|
{
|
|
if (m_indices.getType() != TWML_TYPE_INT64) {
|
|
throw twml::Error(TWML_ERR_TYPE, "Indices of Sparse Tensor must be of type int64");
|
|
}
|
|
}
|
|
|
|
const RawTensor &indices() const {
|
|
return m_indices;
|
|
}
|
|
|
|
const RawTensor &values() const {
|
|
return m_values;
|
|
}
|
|
|
|
const std::vector<uint64_t>& denseShape() const {
|
|
return m_dense_shape;
|
|
}
|
|
};
|
|
|
|
}
|
|
#endif
|