-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathquantmatrix.cc
117 lines (102 loc) · 3.29 KB
/
quantmatrix.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
/**
* Copyright (c) 2016-present, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
#include "quantmatrix.h"

#include <assert.h>
#include <iostream>
#include <stdexcept>
#include <utility>
namespace fasttext {
QuantMatrix::QuantMatrix() : Matrix(), qnorm_(false), codesize_(0) {}
// Builds a quantized matrix from the dense matrix `mat`.
//
// mat   - dense source matrix. Its two sizes become the dimensions of the
//         Matrix base, and it is handed on to quantize(), which may rescale
//         it in place when `qnorm` is set.
// dsub  - sub-vector width given to the ProductQuantizer. The code buffer
//         holds mat.size(0) * ceil(mat.size(1) / dsub) bytes. `dsub` is used
//         as a divisor in the initializer list, so it must not be zero.
// qnorm - when true, a second one-dimensional quantizer is created and one
//         extra code per row is stored for the row norm.
QuantMatrix::QuantMatrix(DenseMatrix&& mat, int32_t dsub, bool qnorm)
    : Matrix(mat.size(0), mat.size(1)),
      qnorm_(qnorm),
      codesize_(mat.size(0) * ((mat.size(1) + dsub - 1) / dsub)) {
  codes_.resize(codesize_);
  pq_.reset(new ProductQuantizer(n_, dsub));
  if (qnorm_) {
    norm_codes_.resize(m_);
    npq_.reset(new ProductQuantizer(1, 1));
  }
  // `mat` is a plain (non-deduced) rvalue-reference parameter, so std::move
  // is the right cast here; std::forward is reserved for forwarding
  // references in templates.
  quantize(std::move(mat));
}
// Trains the norm quantizer (npq_) on the values in `norms` and writes one
// code per entry into norm_codes_.
//
// Preconditions (checked with assert): norm quantization is enabled and
// `norms` holds exactly m_ values.
void QuantMatrix::quantizeNorm(const Vector& norms) {
  assert(qnorm_);
  assert(norms.size() == m_);

  const auto normValues = norms.data();
  npq_->train(m_, normValues);
  npq_->compute_codes(normValues, norm_codes_.data(), m_);
}
// Trains the main quantizer (pq_) on `mat` and fills codes_ with the codes
// for its m_ rows.
//
// When norm quantization is enabled, mat.l2NormRow() and mat.divideRow() are
// first applied with a per-row vector (by their names: compute the row L2
// norms, then divide each row by its norm), and those norms are quantized
// separately through quantizeNorm(). `mat` is therefore modified in place.
void QuantMatrix::quantize(DenseMatrix&& mat) {
  if (qnorm_) {
    Vector rowNorms(mat.size(0));
    mat.l2NormRow(rowNorms);
    mat.divideRow(rowNorms);
    quantizeNorm(rowNorms);
  }

  const auto values = mat.data();
  pq_->train(m_, values);
  pq_->compute_codes(values, codes_.data(), m_);
}
// Multiplies `vec` with quantized row `i` by delegating to
// ProductQuantizer::mulcode and returns the result.
//
// vec - input vector; must have n_ entries (asserted).
// i   - row index; must satisfy 0 <= i < m_ (asserted).
//
// The scale passed to mulcode is 1, or, when norm quantization is enabled,
// the first component of the norm centroid selected by this row's norm code.
real QuantMatrix::dotRow(const Vector& vec, int64_t i) const {
  assert(i >= 0);
  assert(i < m_);
  assert(vec.size() == n_);

  real scale = qnorm_ ? npq_->get_centroids(0, norm_codes_[i])[0] : real(1);
  return pq_->mulcode(vec, codes_.data(), i, scale);
}
// Not supported on a quantized matrix: all three arguments are ignored and
// std::runtime_error is thrown unconditionally.
void QuantMatrix::addVectorToRow(
    const Vector& /*unused*/,
    int64_t /*unused*/,
    real /*unused*/) {
  throw std::runtime_error("Operation not permitted on quantized matrices.");
}
// Accumulates quantized row `i` into `x` through ProductQuantizer::addcode,
// scaled by `a`.
//
// x - vector that receives the contribution (modified in place).
// i - row index; must satisfy 0 <= i < m_ (asserted, as in dotRow).
// a - caller-supplied scale factor.
//
// When norm quantization is enabled, `a` is additionally multiplied by the
// first component of the norm centroid selected by this row's norm code.
void QuantMatrix::addRowToVector(Vector& x, int32_t i, real a) const {
  // Same index checks as dotRow(): `i` indexes norm_codes_ and the code
  // buffer without any other bounds check.
  assert(i >= 0);
  assert(i < m_);
  real norm = 1;
  if (qnorm_) {
    norm = npq_->get_centroids(0, norm_codes_[i])[0];
  }
  pq_->addcode(x, codes_.data(), i, a * norm);
}
// Accumulates quantized row `i` into `x` through ProductQuantizer::addcode.
//
// x - vector that receives the contribution (modified in place).
// i - row index; must satisfy 0 <= i < m_ (asserted, as in dotRow).
//
// The scale passed to addcode is 1, or, when norm quantization is enabled,
// the first component of the norm centroid selected by this row's norm code.
void QuantMatrix::addRowToVector(Vector& x, int32_t i) const {
  // Same index checks as dotRow(): `i` indexes norm_codes_ and the code
  // buffer without any other bounds check.
  assert(i >= 0);
  assert(i < m_);
  real norm = 1;
  if (qnorm_) {
    norm = npq_->get_centroids(0, norm_codes_[i])[0];
  }
  pq_->addcode(x, codes_.data(), i, norm);
}
// Serializes the matrix to `out` as raw in-memory bytes, in this order:
//   qnorm_, m_, n_, codesize_, the codesize_ code bytes, the state written by
//   pq_->save(), and - only when qnorm_ is set - the m_ norm-code bytes
//   followed by the state written by npq_->save().
// Fields are written with their native size and byte order; load() reads them
// back in the same sequence.
void QuantMatrix::save(std::ostream& out) const {
  const auto writeBytes = [&out](const void* src, std::streamsize count) {
    out.write(static_cast<const char*>(src), count);
  };

  writeBytes(&qnorm_, sizeof(qnorm_));
  writeBytes(&m_, sizeof(m_));
  writeBytes(&n_, sizeof(n_));
  writeBytes(&codesize_, sizeof(codesize_));
  writeBytes(codes_.data(), codesize_ * sizeof(uint8_t));
  pq_->save(out);

  if (!qnorm_) {
    return;
  }
  writeBytes(norm_codes_.data(), m_ * sizeof(uint8_t));
  npq_->save(out);
}
// Restores the matrix from `in`, reading the fields in the order save()
// writes them: qnorm_, m_, n_, codesize_, the code bytes, the main quantizer
// and - only when qnorm_ is set - the norm codes and the norm quantizer.
//
// The sizes come straight from the stream and drive the buffer allocations,
// so they are validated before use.
//
// Throws std::runtime_error if a read performed here fails (for example on a
// truncated stream) or if a stored dimension is negative. The object may be
// left partially updated when that happens.
void QuantMatrix::load(std::istream& in) {
  in.read(reinterpret_cast<char*>(&qnorm_), sizeof(qnorm_));
  in.read(reinterpret_cast<char*>(&m_), sizeof(m_));
  in.read(reinterpret_cast<char*>(&n_), sizeof(n_));
  in.read(reinterpret_cast<char*>(&codesize_), sizeof(codesize_));
  // Reject bad headers before allocating: a negative size would otherwise be
  // converted to a huge unsigned element count.
  if (!in || m_ < 0 || n_ < 0 || codesize_ < 0) {
    throw std::runtime_error("Invalid or truncated quantized matrix header.");
  }

  codes_ = std::vector<uint8_t>(codesize_);
  in.read(reinterpret_cast<char*>(codes_.data()), codesize_ * sizeof(uint8_t));
  if (!in) {
    throw std::runtime_error("Truncated quantized matrix codes.");
  }
  pq_ = std::unique_ptr<ProductQuantizer>(new ProductQuantizer());
  pq_->load(in);

  if (qnorm_) {
    norm_codes_ = std::vector<uint8_t>(m_);
    in.read(reinterpret_cast<char*>(norm_codes_.data()), m_ * sizeof(uint8_t));
    if (!in) {
      throw std::runtime_error("Truncated quantized matrix norm codes.");
    }
    npq_ = std::unique_ptr<ProductQuantizer>(new ProductQuantizer());
    npq_->load(in);
  }
}
// Not supported on a quantized matrix: the stream argument is ignored and
// std::runtime_error is thrown unconditionally.
void QuantMatrix::dump(std::ostream& /*unused*/) const {
  throw std::runtime_error("Operation not permitted on quantized matrices.");
}
} // namespace fasttext