Optimize BatchPredictionResponse class

This commit is contained in:
Tarek Abdellatef 2023-05-15 11:55:45 +02:00
parent 90d7ea370e
commit d271ca7e2e

View File

@ -37,8 +37,8 @@ BatchPredictionResponse::BatchPredictionResponse(
std::vector<uint64_t> batch_sizes;
batch_sizes.reserve(dense_values_.size());
for (int i = 0; i < dense_values_.size(); i++)
batch_sizes.push_back(dense_values_.at(i).getDim(0));
for (const auto& value : dense_values_)
batch_sizes.emplace_back(value.getDim(0));
if (std::adjacent_find(
batch_sizes.begin(),
@ -76,14 +76,17 @@ void BatchPredictionResponse::serializePredictions(twml::ThriftWriter &thrift_wr
thrift_writer.writeStructFieldHeader(TTYPE_LIST, BPR_PREDICTIONS);
thrift_writer.writeListHeader(TTYPE_STRUCT, getBatchSize());
for (int i = 0; i < getBatchSize(); i++) {
auto batchSize = getBatchSize();
auto predictionSize = getPredictionSize();
for (int i = 0; i < batchSize; i++) {
twml::DataRecord record = twml::DataRecord();
if (hasContinuous()) {
const T *values = values_.getData<T>();
const int64_t *local_keys = keys_.getData<int64_t>();
const T *local_values = values + (i * getPredictionSize());
record.addContinuous(local_keys, getPredictionSize(), local_values);
const T *local_values = values + (i * predictionSize);
record.addContinuous(local_keys, predictionSize, local_values);
}
if (hasDenseTensors()) {