the-algorithm/twml/libtwml/include/twml/TensorRecordReader.h
twitter-team ef4c5eb65e Twitter Recommendation Algorithm
Please note we have force-pushed a new initial commit in order to remove some publicly-available Twitter user information. Note that this process may be required in the future.
2023-03-31 17:36:31 -05:00

35 lines
806 B
C++

#pragma once
#ifdef __cplusplus
#include <twml/defines.h>
#include <twml/TensorRecord.h>
#include <twml/ThriftReader.h>
#include <cstdint>
#include <vector>
#include <string>
#include <unordered_map>
namespace twml {
// Reader for the thrift objects defined in tensor.thrift.
// Derives from ThriftReader and forwards the input buffer to it.
class TWMLAPI TensorRecordReader : public ThriftReader {
 public:
  // Constructs a reader over `buffer`; the pointer is passed unchanged to
  // the ThriftReader base.
  // NOTE(review): buffer ownership and lifetime are not visible in this
  // header -- confirm against ThriftReader.
  TensorRecordReader(const uint8_t *buffer) : ThriftReader(buffer) {}

  // Public entry points. Each takes a feature type and the TensorRecord to
  // operate on.
  // NOTE(review): presumably these decode one dense / sparse tensor and
  // store it in `record`; confirm against the implementation file.
  void readTensor(const int feature_type, TensorRecord *record);
  void readSparseTensor(const int feature_type, TensorRecord *record);

 private:
  // Internal helpers used by the public readers; they are defined outside
  // this header.
  // NOTE(review): names suggest one helper per tensor encoding in
  // tensor.thrift -- verify before relying on that.
  std::vector<uint64_t> readShape();

  template <typename T>
  RawTensor readTypedTensor();

  RawTensor readRawTypedTensor();
  RawTensor readStringTensor();
  RawTensor readGeneralTensor();
  RawSparseTensor readCOOSparseTensor();
};
}
#endif