Skip to content

Commit 6b41eda

Browse files
AnthonyBlinkop
authored and committed
Tensor template class created
1 parent 149878b commit 6b41eda

File tree

1 file changed

+304
-0
lines changed

1 file changed

+304
-0
lines changed
Lines changed: 304 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,304 @@
1+
#ifndef CSLIB_DATA_STRUCTURE_TENSOR_HPP
2+
#define CSLIB_DATA_STRUCTURE_TENSOR_HPP
3+
4+
#include <cstdint>
#include <functional>
#include <numeric>
#include <string>
#include <utility>
#include <vector>
8+
9+
namespace cslib {
10+
namespace data_structure
11+
{
12+
/* Class representing an n-dimensional tensor object.
 *
 * Ownership model (as designed by the original author):
 *  - A tensor built from Tensor(dims) owns its buffer and frees it on destruction.
 *  - Copy construction / copy assignment produce SHALLOW copies that alias the
 *    source buffer; a shallow copy never frees the buffer.
 *  - operator[] returns a shallow sub-tensor view into the parent's buffer;
 *    assigning into such a view writes through to the parent's elements.
 *  - make_copy() is the only deep-copy operation.
 *
 * TODO:
 * -broadcasting
 * -range cutting
 * -linear algebra operations
 * -dimension operations (expand, reshape, ...)
 */
template <typename T>
class Tensor
{
public:
    using dimensions = std::vector<size_t>;

private:
    T* _data = nullptr;            // element storage (owned unless shallow)
    size_t _number_elements = 0;   // product of all dimension extents (was int: narrowing)
    dimensions _dims;              // extent of each dimension

    bool _is_shallow_copy = false; // true -> _data aliases another tensor's buffer
    bool _is_sub_tensor = false;   // true -> this is a view produced by operator[]

    // Product of all extents. size_t arithmetic throughout avoids the
    // int init value / std::multiplies<size_t> mismatch of the original.
    static size_t element_count(const dimensions& dims)
    {
        return std::accumulate(dims.begin(), dims.end(), size_t{1}, std::multiplies<size_t>());
    }

public:
    // Constructors

    // Allocates an uninitialized, owning buffer large enough for `dims`.
    Tensor(const dimensions& dims) :
        _dims(dims)
    {
        _number_elements = element_count(_dims);
        _data = new T[_number_elements];
    }

    // Shallow copy: aliases rhs's buffer and never frees it.
    // FIX: the original called data_release() here before any member was
    // initialized, reading the indeterminate _is_shallow_copy flag (UB) and
    // potentially delete[]-ing a garbage pointer.
    Tensor(const Tensor& rhs) :
        _data(rhs._data),
        _number_elements(rhs._number_elements),
        _dims(rhs._dims),
        _is_shallow_copy(true),
        _is_sub_tensor(rhs._is_sub_tensor)
    {
    }

    // Move: steals rhs's buffer and neutralizes rhs so its destructor is a no-op.
    // FIX: same pre-initialization data_release() UB as the copy constructor.
    Tensor(Tensor&& rhs) noexcept :
        _data(rhs._data),
        _number_elements(rhs._number_elements),
        _dims(std::move(rhs._dims)),
        _is_shallow_copy(rhs._is_shallow_copy),
        _is_sub_tensor(rhs._is_sub_tensor)
    {
        rhs._data = nullptr;
        rhs._number_elements = 0;
        rhs._dims = {};
        rhs._is_shallow_copy = true; // rhs must never free the stolen buffer
        rhs._is_sub_tensor = false;
    }

    // Destructor
    ~Tensor() noexcept
    {
        data_release();
    }

    // Operators overloading

    // Copy assignment.
    //  - Assigning INTO a sub-tensor view copies elements through the view
    //    (only when shapes match exactly), mutating the parent's storage.
    //  - Otherwise this tensor becomes a shallow alias of rhs, mirroring
    //    copy-construction semantics.
    Tensor& operator=(const Tensor& rhs)
    {
        // FIX: without this guard, self-assignment of an owning tensor freed
        // _data and then aliased the just-freed buffer.
        if (this == &rhs)
            return *this;

        if (_is_shallow_copy && _is_sub_tensor) // writing through a view
        {
            // assert (_dims == rhs._dims)?
            if (_dims == rhs._dims)
                for (size_t i = 0; i < _number_elements; i++)
                    _data[i] = rhs._data[i];
        }
        else // tensor is not a sub-tensor: become a shallow alias
        {
            data_release();

            _data = rhs._data;
            _number_elements = rhs._number_elements;
            _dims = rhs._dims;
            _is_shallow_copy = true;
        }

        return *this;
    }

    // Move assignment: element-wise copy when this is a view, otherwise takes
    // over rhs's buffer. rhs is always left empty and non-owning.
    Tensor& operator=(Tensor&& rhs) noexcept
    {
        if (this == &rhs)
            return *this;

        if (_is_shallow_copy && _is_sub_tensor) // writing through a view
        {
            // assert (_dims == rhs._dims)?
            if (_dims == rhs._dims)
                for (size_t i = 0; i < _number_elements; i++)
                    _data[i] = rhs._data[i];
        }
        else // tensor is not a sub-tensor: steal rhs's buffer
        {
            data_release();

            _data = rhs._data;
            _number_elements = rhs._number_elements;
            _dims = rhs._dims;
            // FIX: the original forced _is_shallow_copy = true here while also
            // neutralizing rhs below, leaving the buffer with no owner (leak).
            // Ownership must transfer with the pointer.
            _is_shallow_copy = rhs._is_shallow_copy;
        }

        rhs._data = nullptr;
        rhs._number_elements = 0;
        rhs._dims = {};
        rhs._is_shallow_copy = true;
        rhs._is_sub_tensor = false;

        return *this;
    }

    // Fills every element with `value` (works for owning tensors and views alike).
    Tensor& operator=(T value)
    {
        for (size_t i = 0; i < _number_elements; i++)
            _data[i] = value;

        return *this;
    }

    // Returns a deep copy with `value` added to every element.
    Tensor operator+(T value) const
    {
        Tensor result = make_copy();

        for (size_t i = 0; i < _number_elements; i++)
            result._data[i] = _data[i] + value;

        return result;
    }

    // Element-wise addition. Shapes must match exactly (no broadcasting yet);
    // on mismatch an unmodified deep copy of *this is returned.
    Tensor operator+(const Tensor& rhs) const
    {
        Tensor result = make_copy();

        if (rhs._dims == _dims) // element-wise addition
        {
            for (size_t i = 0; i < _number_elements; i++)
                result._data[i] = _data[i] + rhs._data[i];
        }

        return result;
    }

    // Returns a deep copy with every element multiplied by `value`.
    Tensor operator*(T value) const
    {
        Tensor result = make_copy();

        for (size_t i = 0; i < _number_elements; i++)
            result._data[i] = _data[i] * value;

        return result;
    }

    // Element-wise (Hadamard) multiplication; same shape rules as operator+.
    Tensor operator*(const Tensor& rhs) const
    {
        Tensor result = make_copy();

        if (rhs._dims == _dims) // element-wise multiply
        {
            for (size_t i = 0; i < _number_elements; i++)
                result._data[i] = _data[i] * rhs._data[i];
        }

        return result;
    }

    // Returns a shallow sub-tensor view of the slice at `index` along the
    // first dimension. No bounds checking is performed. Indexing a 1-D
    // tensor yields a 1-element view.
    Tensor operator[](int index)
    {
        auto sub_dims = _dims;
        sub_dims.erase(sub_dims.begin());

        if (sub_dims.empty())
            sub_dims.push_back(1);

        T* slice_start = _data + index * element_count(sub_dims);
        return Tensor(sub_dims, slice_start);
    }

    // Const overload; the returned view still aliases this tensor's buffer.
    const Tensor operator[](int index) const
    {
        auto sub_dims = _dims;
        sub_dims.erase(sub_dims.begin());

        if (sub_dims.empty())
            sub_dims.push_back(1);

        T* slice_start = _data + index * element_count(sub_dims);
        return Tensor(sub_dims, slice_start);
    }

    // deep copy operations

    // Returns a new owning tensor with the same shape and copied elements.
    Tensor make_copy() const
    {
        auto copied_tensor = Tensor(_dims);
        for (size_t i = 0; i < _number_elements; i++)
            copied_tensor._data[i] = _data[i];

        return copied_tensor;
    }

    // Linear algebra

    // In-place 2-D transpose: swaps the two extents AND rearranges the data.
    // FIX: the original only swapped the extents (leaving elements scrambled)
    // and indexed _dims[1] without checking the rank. Non-2-D tensors are
    // left untouched.
    void transpose()
    {
        if (_dims.size() != 2)
            return;

        const size_t rows = _dims[0];
        const size_t cols = _dims[1];

        // Out-of-place copy keeps the logic simple and works for views too,
        // since the tensor's own buffer is rewritten in place.
        std::vector<T> scratch(_data, _data + _number_elements);
        for (size_t r = 0; r < rows; r++)
            for (size_t c = 0; c < cols; c++)
                _data[c * rows + r] = scratch[r * cols + c];

        std::swap(_dims[0], _dims[1]);
    }

    // String representation

    // Renders the tensor as nested bracketed rows, e.g. "[[1, 2]\n [3, 4]]".
    // Tensors with more than two dimensions are rendered as their 2-D slices
    // separated by blank lines.
    std::string to_string() const
    {
        if (_dims.empty())
            return "[]";

        std::string result;
        const auto num_dims = _dims.size();

        if (num_dims > 2)
        {
            for (size_t i = 0; i < _dims[0]; i++)
            {
                if (i != 0)
                    result += "\n\n";
                result += (*this)[static_cast<int>(i)].to_string();
            }
        }
        else if (num_dims == 2)
        {
            const size_t rows = _dims[0];
            const size_t cols = _dims[1];

            result += "[";
            for (size_t r = 0; r < rows; r++)
            {
                result += (r == 0) ? "[" : " [";
                for (size_t c = 0; c < cols; c++)
                {
                    result += std::to_string(_data[r * cols + c]);
                    if (c + 1 != cols)
                        result += ", ";
                }
                result += "]";
                // FIX: the original compared the flat element index against
                // the row count, emitting a stray newline after the last row.
                if (r + 1 != rows)
                    result += "\n";
            }
            result += "]";
        }
        else
        {
            // FIX: the original appended "] " here and then unconditionally
            // chopped two characters, truncating the closing bracket of every
            // 1-D tensor.
            result += "[";
            for (size_t i = 0; i < _dims[0]; i++)
            {
                result += std::to_string(_data[i]);
                if (i + 1 != _dims[0])
                    result += ", ";
            }
            result += "]";
        }

        return result;
    }

private:
    // sub-tensor ctor: builds a non-owning view over `data_start_address`.
    Tensor(const dimensions& dims, T* data_start_address) :
        _data(data_start_address),
        _dims(dims),
        _is_shallow_copy(true),
        _is_sub_tensor(true)
    {
        _number_elements = element_count(_dims);
    }

    // Memory management

    // Frees the buffer only when this tensor owns it; views and shallow
    // copies are no-ops. Nulling _data makes double-release harmless.
    void data_release()
    {
        if (!_is_shallow_copy)
        {
            delete[] _data;
            _data = nullptr;
        }
    }
};
301+
302+
}}
303+
304+
#endif // CSLIB_DATA_STRUCTURE_TENSOR_HPP

0 commit comments

Comments
 (0)