mirror of
https://github.com/twitter/the-algorithm.git
synced 2024-06-13 14:48:54 +02:00
Optimize BatchPredictionResponse class
This commit is contained in:
parent
90d7ea370e
commit
d271ca7e2e
|
@ -37,8 +37,8 @@ BatchPredictionResponse::BatchPredictionResponse(
|
|||
std::vector<uint64_t> batch_sizes;
|
||||
batch_sizes.reserve(dense_values_.size());
|
||||
|
||||
for (int i = 0; i < dense_values_.size(); i++)
|
||||
batch_sizes.push_back(dense_values_.at(i).getDim(0));
|
||||
for (const auto& value : dense_values_)
|
||||
batch_sizes.emplace_back(value.getDim(0));
|
||||
|
||||
if (std::adjacent_find(
|
||||
batch_sizes.begin(),
|
||||
|
@ -76,14 +76,17 @@ void BatchPredictionResponse::serializePredictions(twml::ThriftWriter &thrift_wr
|
|||
thrift_writer.writeStructFieldHeader(TTYPE_LIST, BPR_PREDICTIONS);
|
||||
thrift_writer.writeListHeader(TTYPE_STRUCT, getBatchSize());
|
||||
|
||||
for (int i = 0; i < getBatchSize(); i++) {
|
||||
auto batchSize = getBatchSize();
|
||||
auto predictionSize = getPredictionSize();
|
||||
|
||||
for (int i = 0; i < batchSize; i++) {
|
||||
twml::DataRecord record = twml::DataRecord();
|
||||
|
||||
if (hasContinuous()) {
|
||||
const T *values = values_.getData<T>();
|
||||
const int64_t *local_keys = keys_.getData<int64_t>();
|
||||
const T *local_values = values + (i * getPredictionSize());
|
||||
record.addContinuous(local_keys, getPredictionSize(), local_values);
|
||||
const T *local_values = values + (i * predictionSize);
|
||||
record.addContinuous(local_keys, predictionSize, local_values);
|
||||
}
|
||||
|
||||
if (hasDenseTensors()) {
|
||||
|
|
Loading…
Reference in New Issue
Block a user