
Commit ee43bda

issue/682 Parameter supports TP, fix copy_from
1 parent c973b0b commit ee43bda

16 files changed: +335, -70 lines


include/infinicore/context/context.hpp

Lines changed: 1 addition & 1 deletion
@@ -11,7 +11,7 @@
 namespace infinicore {
 
 namespace context {
-void setDevice(Device device);
+void setDevice(Device device, bool force_cpu = false);
 Device getDevice();
 size_t getDeviceCount(Device::Type type);
 

include/infinicore/device.hpp

Lines changed: 4 additions & 0 deletions
@@ -39,6 +39,10 @@ class Device {
 
     bool operator!=(const Device &other) const;
 
+    inline static Device cpu() {
+        return Device(Type::CPU, 0);
+    }
+
 private:
     Type type_;
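
The new Device::cpu() helper is a shorthand for the default host device (Type::CPU, index 0). The tensor-parallel test added further below uses it to stage full weights on the host before slicing them; a usage sketch based on that test code (w_data is the host buffer defined there):

    // Usage sketch, mirroring test_nn_module.cc below: build a host-side tensor
    // from a raw buffer, using the new Device::cpu() shorthand.
    auto full_weight = infinicore::Tensor::from_blob(
        w_data.data(), {32, 64},
        infinicore::DataType::F32,
        infinicore::Device::cpu());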

include/infinicore/nn/parameter.hpp

Lines changed: 12 additions & 1 deletion
@@ -9,8 +9,19 @@ class Parameter : public Tensor {
 
     Parameter(const Shape &shape,
               const DataType &dtype,
-              const Device &device);
+              const Device &device,
+              Size tp_dim = 0,
+              Size tp_rank = 0,
+              Size tp_size = 1);
 
     void load_blob(const void *data);
+
+    void load(const Tensor &tensor);
+
+protected:
+    // Tensor parallel configs
+    Size tp_dim_;  // dimension partitioned
+    Size tp_rank_; // rank of this partition among tp group
+    Size tp_size_; // total number of partitions
 };
 } // namespace infinicore::nn
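
The new tp_dim/tp_rank/tp_size arguments describe how a parameter is split across a tensor-parallel group. Judging from the tests in this commit, each rank allocates only its own slice, with the size along tp_dim divided by tp_size (a {8, 4} parameter with tp_dim=0, tp_size=2 becomes {4, 4} locally), and load()/load_blob() copy in only that rank's portion of the full weight. A minimal sketch of the shape arithmetic; partition_shape is a hypothetical helper, not part of this commit:

    // Sketch only: the local shape implied by the tests below.
    // partition_shape() is a hypothetical helper, not part of this commit.
    #include <cstddef>
    #include <vector>

    std::vector<std::size_t> partition_shape(std::vector<std::size_t> full_shape,
                                             std::size_t tp_dim, std::size_t tp_size) {
        // Each rank keeps 1/tp_size of the partitioned dimension,
        // e.g. {8, 4} with tp_dim = 0, tp_size = 2 -> {4, 4}.
        full_shape[tp_dim] /= tp_size;
        return full_shape;
    }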

python/infinicore/context.py

Lines changed: 2 additions & 2 deletions
@@ -23,13 +23,13 @@ def get_device_count(device_type):
     return _infinicore.get_device_count(infinicore.device(device_type)._underlying.type)
 
 
-def set_device(device):
+def set_device(device, force_cpu=False):
     """Set the current active device.
 
     Args:
         device: The device to set as active
     """
-    _infinicore.set_device(device._underlying)
+    _infinicore.set_device(device._underlying, force_cpu)
 
 
 def sync_stream():

src/infinicore-test/memory_test.cc

Lines changed: 0 additions & 3 deletions
@@ -709,9 +709,6 @@ TestResult PerformanceTest::testMemoryCopyPerformance() {
         return false;
     }
 
-    // Initialize source data
-    std::memset(src_memory->data(), 0xAB, data_size);
-
     auto start = std::chrono::high_resolution_clock::now();
 
     // Perform memory copies

src/infinicore-test/test_nn_module.cc

Lines changed: 195 additions & 10 deletions
@@ -3,6 +3,20 @@
 
 namespace infinicore::test {
 
+// Helper function to format shape for logging
+inline std::string formatShape(const std::vector<size_t> &shape) {
+    std::ostringstream oss;
+    oss << "[";
+    for (size_t i = 0; i < shape.size(); ++i) {
+        if (i > 0) {
+            oss << ", ";
+        }
+        oss << shape[i];
+    }
+    oss << "]";
+    return oss.str();
+}
+
 // Test 1: Basic module operations (creation, parameters, state_dict, load_state_dict)
 TestResult NNModuleTest::testBasicModuleCreation() {
     return measureTime("BasicModuleOperations", [this]() {
@@ -115,6 +129,174 @@ TestResult NNModuleTest::testBasicModuleCreation() {
     });
 }
 
+TestResult NNModuleTest::testTensorParallelParameters() {
+    return measureTime("TensorParallelParameters", [this]() {
+        try {
+            spdlog::info("==========================================");
+            spdlog::info("Testing Tensor Parallel Parameters");
+            spdlog::info("==========================================");
+
+            auto device = infinicore::context::getDevice();
+
+            spdlog::info("Test Tensor Parallel Parameter");
+            // Case 1: Partition along dimension 0 (row-wise partitioning)
+            infinicore::nn::Parameter param_dim0({8, 4}, infinicore::DataType::F32, device, 0, 0, 2);
+            if (param_dim0->shape() != std::vector<size_t>({4, 4})) {
+                spdlog::error("TP dim0: Expected shape [4, 4], got [{}]", formatShape(param_dim0->shape()));
+                return false;
+            }
+            spdlog::info("✓ TP dim0 parameter created with correct partitioned shape");
+            // Case 2: Partition along dimension 1 (column-wise partitioning)
+            infinicore::nn::Parameter param_dim1({8, 4}, infinicore::DataType::F32, device, 1, 0, 2);
+            if (param_dim1->shape() != std::vector<size_t>({8, 2})) {
+                spdlog::error("TP dim1: Expected shape [8, 2], got [{}]", formatShape(param_dim1->shape()));
+                return false;
+            }
+            spdlog::info("✓ TP dim1 parameter created with correct partitioned shape");
+            spdlog::info("✓ Parameter creation with tensor parallelism passed");
+
+            spdlog::info("Test Tensor Parallel Linear Module");
+            auto w_data = std::vector<float>(32 * 64);
+            auto b_data = std::vector<float>(32);
+            for (size_t i = 0; i < 32; ++i) {
+                for (size_t j = 0; j < 64; ++j) {
+                    w_data[i * 64 + j] = static_cast<float>(j);
+                }
+                b_data[i] = static_cast<float>(i);
+            }
+            {
+                spdlog::info("Test tp_size=4 tp_dim=0");
+                Size tp_size = 4;
+                Size tp_dim = 0;
+                std::vector<std::unique_ptr<MockLinearModule>> tp_modules;
+
+                for (Size tp_rank = 0; tp_rank < tp_size; ++tp_rank) {
+                    auto module = std::make_unique<MockLinearModule>(64, 32, device, tp_dim, tp_rank, tp_size);
+                    tp_modules.push_back(std::move(module));
+                }
+
+                // Verify each partition has correct shape
+                for (size_t i = 0; i < tp_modules.size(); ++i) {
+                    const auto &weight = tp_modules[i]->get_weight();
+                    const auto &bias = tp_modules[i]->get_bias();
+
+                    // Weight should be partitioned along output dimension (dim 0)
+                    if (weight->shape() != std::vector<size_t>({8, 64})) { // 32/4 = 8
+                        spdlog::error("TP rank {}: Weight shape mismatch. Expected [8, 64], got [{}]",
+                                      i, formatShape(weight->shape()));
+                        return false;
+                    }
+
+                    // Bias should be partitioned along output dimension
+                    if (bias->shape() != std::vector<size_t>({8})) { // 32/4 = 8
+                        spdlog::error("TP rank {}: Bias shape mismatch. Expected [8], got [{}]",
+                                      i, formatShape(bias->shape()));
+                        return false;
+                    }
+
+                    spdlog::debug("TP rank {}: weight shape [{}], bias shape [{}]",
+                                  i, formatShape(weight->shape()), formatShape(bias->shape()));
+
+                    tp_modules[i]->load_parameter_from_blob("weight", w_data.data());
+                    tp_modules[i]->load_parameter_from_blob("bias", b_data.data());
+
+                    auto weight_loaded = infinicore::Tensor::from_blob(
+                                             w_data.data(),
+                                             {32, 64},
+                                             infinicore::DataType::F32,
+                                             infinicore::Device::cpu())
+                                             ->narrow({{0, i * 8, 8}})
+                                             ->to(device); // Narrow to get the partition
+                    auto bias_loaded = infinicore::Tensor::from_blob(
+                                           b_data.data(),
+                                           {32},
+                                           infinicore::DataType::F32,
+                                           infinicore::Device::cpu())
+                                           ->narrow({{0, i * 8, 8}})
+                                           ->to(device); // Narrow to get the partition
+
+                    if (!tensorsAllClose(tp_modules[i]->get_weight(), weight_loaded, 1e-6, 1e-6)) {
+                        spdlog::error("TP rank {}: Weight values do not match after load_parameter_from_blob", i);
+                        return false;
+                    }
+
+                    if (!tensorsAllClose(tp_modules[i]->get_bias(), bias_loaded, 1e-6, 1e-6)) {
+                        spdlog::error("TP rank {}: Bias values do not match after load_parameter_from_blob", i);
+                        return false;
+                    }
+                }
+            }
+
+            {
+                spdlog::info("Test tp_size=4 tp_dim=1");
+                Size tp_size = 4;
+                Size tp_dim = 1;
+                std::vector<std::unique_ptr<MockLinearModule>> tp_modules;
+
+                for (Size tp_rank = 0; tp_rank < tp_size; ++tp_rank) {
+                    auto module = std::make_unique<MockLinearModule>(64, 32, device, tp_dim, tp_rank, tp_size);
+                    tp_modules.push_back(std::move(module));
+                }
+
+                // Verify each partition has correct shape
+                for (size_t i = 0; i < tp_modules.size(); ++i) {
+                    const auto &weight = tp_modules[i]->get_weight();
+                    const auto &bias = tp_modules[i]->get_bias();
+
+                    // Weight should be partitioned along dimension 1 (the input dimension)
+                    if (weight->shape() != std::vector<size_t>({32, 16})) { // 64/4 = 16
+                        spdlog::error("TP rank {}: Weight shape mismatch. Expected [32, 16], got [{}]",
+                                      i, formatShape(weight->shape()));
+                        return false;
+                    }
+
+                    // Bias keeps its full size because the output dimension is not split
+                    if (bias->shape() != std::vector<size_t>({32})) { // Bias not partitioned when tp_dim=1
+                        spdlog::error("TP rank {}: Bias shape mismatch. Expected [32], got [{}]",
+                                      i, formatShape(bias->shape()));
+                        return false;
+                    }
+
+                    spdlog::debug("TP rank {}: weight shape [{}], bias shape [{}]",
+                                  i, formatShape(weight->shape()), formatShape(bias->shape()));
+
+                    tp_modules[i]->load_parameter_from_blob("weight", w_data.data());
+                    tp_modules[i]->load_parameter_from_blob("bias", b_data.data());
+
+                    auto weight_loaded = infinicore::Tensor::from_blob(
+                                             w_data.data(),
+                                             {32, 64},
+                                             infinicore::DataType::F32,
+                                             infinicore::Device::cpu())
+                                             ->narrow({{1, i * 16, 16}})
+                                             ->to(device); // Narrow to get the partition
+                    auto bias_loaded = infinicore::Tensor::from_blob(
+                                           b_data.data(),
+                                           {32},
+                                           infinicore::DataType::F32,
+                                           infinicore::Device::cpu())
+                                           ->to(device); // Full bias; no narrowing needed when tp_dim=1
+                    if (!tensorsAllClose(tp_modules[i]->get_weight(), weight_loaded, 1e-6, 1e-6)) {
+                        spdlog::error("TP rank {}: Weight values do not match after load_parameter_from_blob", i);
+                        return false;
+                    }
+                    if (!tensorsAllClose(tp_modules[i]->get_bias(), bias_loaded, 1e-6, 1e-6)) {
+                        spdlog::error("TP rank {}: Bias values do not match after load_parameter_from_blob", i);
+                        return false;
+                    }
+                }
+            }
+
+            spdlog::info("=== All Tensor Parallel Parameter Tests Passed ===");
+            return true;
+
+        } catch (const std::exception &e) {
+            spdlog::error("Exception in testTensorParallelParameters: {}", e.what());
+            return false;
+        }
+    });
+}
+
 // Test 2: Advanced load state dict functionality (hierarchical modules)
 TestResult NNModuleTest::testLoadStateDict() {
     return measureTime("AdvancedLoadStateDict", [this]() {
@@ -384,6 +566,8 @@ TestResult NNModuleTest::testParameterLoading() {
             return false;
         }
 
+        MockLinearModule module_row_parallel(3, 2, infinicore::Device(), 0, 1, 2);
+
         spdlog::info("Parameter loading test passed");
         return true;
     } catch (const std::exception &e) {
@@ -1708,16 +1892,17 @@ TestResult NNModuleTest::run() {
               << "InfiniCore nn::Module Test Suite\n"
               << "==============================================" << std::endl;
 
-    results.push_back(testBasicModuleCreation());   // Merged: creation + parameters + state_dict + load
-    results.push_back(testLoadStateDict());         // Advanced: hierarchical modules
-    results.push_back(testModuleHierarchy());       // Demonstrates hierarchical construction
-    results.push_back(testParameterLoading());      // Blob loading
-    results.push_back(testModuleLinear());          // Linear module comprehensive test
-    results.push_back(testModuleEmbedding());       // Embedding module test
-    results.push_back(testModuleRMSNorm());         // RMSNorm module test
-    results.push_back(testModuleRoPE());            // RoPE module test
-    results.push_back(testDtypeAssertion());        // Dtype assertion test
-    results.push_back(testTinyLlamaConstruction()); // Comprehensive: TinyLlama model test
+    results.push_back(testBasicModuleCreation());      // Merged: creation + parameters + state_dict + load
+    results.push_back(testTensorParallelParameters()); // Tensor-parallel parameters
+    results.push_back(testLoadStateDict());            // Advanced: hierarchical modules
+    results.push_back(testModuleHierarchy());          // Demonstrates hierarchical construction
+    results.push_back(testParameterLoading());         // Blob loading
+    results.push_back(testModuleLinear());             // Linear module comprehensive test
+    results.push_back(testModuleEmbedding());          // Embedding module test
+    results.push_back(testModuleRMSNorm());            // RMSNorm module test
+    results.push_back(testModuleRoPE());               // RoPE module test
+    results.push_back(testDtypeAssertion());           // Dtype assertion test
+    results.push_back(testTinyLlamaConstruction());    // Comprehensive: TinyLlama model test
 
     // Check if all tests passed
     bool all_passed = true;
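
The expected shapes asserted in the new test follow directly from dividing the partitioned dimension of the full 32x64 weight (and 32-element bias) by tp_size = 4. A self-contained sketch of that arithmetic, separate from the commit itself:

    // Sketch: the shape arithmetic behind the expected values in the test above.
    #include <cstddef>
    #include <cstdio>

    int main() {
        const std::size_t out = 32, in = 64, tp_size = 4;
        // tp_dim = 0 splits the output dimension: weight {8, 64}, bias {8}
        std::printf("tp_dim=0: weight {%zu, %zu}, bias {%zu}\n", out / tp_size, in, out / tp_size);
        // tp_dim = 1 splits the input dimension: weight {32, 16}, bias stays {32}
        std::printf("tp_dim=1: weight {%zu, %zu}, bias {%zu}\n", out, in / tp_size, out);
        return 0;
    }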

src/infinicore-test/test_nn_module.h

Lines changed: 27 additions & 14 deletions
@@ -26,17 +26,25 @@ class MockLinearModule : public infinicore::nn::Module {
     INFINICORE_NN_PARAMETER(weight);
     INFINICORE_NN_PARAMETER(bias);
 
-    MockLinearModule(int input_size, int output_size, const infinicore::Device &device)
-        : input_size_(input_size), output_size_(output_size), device_(device) {
+    MockLinearModule(int input_size, int output_size, const infinicore::Device &device,
+                     Size tp_dim = 0, Size tp_rank = 0, Size tp_size = 1)
+        : input_size_(input_size), output_size_(output_size), device_(device),
+          tp_dim_(tp_dim), tp_rank_(tp_rank), tp_size_(tp_size) {
         // Initialize parameters using macros
         INFINICORE_NN_PARAMETER_INIT(weight,
                                      ({static_cast<size_t>(output_size), static_cast<size_t>(input_size)},
                                       infinicore::DataType::F32,
-                                      device));
+                                      device,
+                                      tp_dim_,
+                                      tp_rank_,
+                                      tp_size_));
         INFINICORE_NN_PARAMETER_INIT(bias,
                                      ({static_cast<size_t>(output_size)},
                                       infinicore::DataType::F32,
-                                      device));
+                                      device,
+                                      0,
+                                      tp_dim == 0 ? tp_rank_ : 0,
+                                      tp_dim == 0 ? tp_size_ : 1));
     }
 
     // Simple forward pass (conceptual - would need actual matrix operations)
@@ -68,6 +76,10 @@ class MockLinearModule : public infinicore::nn::Module {
     int input_size_;
     int output_size_;
     infinicore::Device device_;
+
+    Size tp_dim_;
+    Size tp_rank_;
+    Size tp_size_;
 };
 
 class NNModuleTest : public TestFramework {
@@ -76,16 +88,17 @@ class NNModuleTest : public TestFramework {
     std::string getName() const override { return "NNModuleTest"; }
 
 private:
-    TestResult testBasicModuleCreation();   // Merged: creation, parameters, state_dict, load_state_dict
-    TestResult testLoadStateDict();         // Advanced: hierarchical modules
-    TestResult testModuleHierarchy();       // Demonstrates proper hierarchical construction pattern
-    TestResult testParameterLoading();      // Test blob parameter loading
-    TestResult testModuleLinear();          // Comprehensive Linear module test
-    TestResult testModuleEmbedding();       // Embedding module test
-    TestResult testModuleRMSNorm();         // RMSNorm module test
-    TestResult testModuleRoPE();            // RoPE module test
-    TestResult testDtypeAssertion();        // Test dtype assertions when loading parameters
-    TestResult testTinyLlamaConstruction(); // Comprehensive: construction + weight loading + validation
+    TestResult testBasicModuleCreation();      // Merged: creation, parameters, state_dict, load_state_dict
+    TestResult testTensorParallelParameters(); // Module with tensor parallel parameters
+    TestResult testLoadStateDict();            // Advanced: hierarchical modules
+    TestResult testModuleHierarchy();          // Demonstrates proper hierarchical construction pattern
+    TestResult testParameterLoading();         // Test blob parameter loading
+    TestResult testModuleLinear();             // Comprehensive Linear module test
+    TestResult testModuleEmbedding();          // Embedding module test
+    TestResult testModuleRMSNorm();            // RMSNorm module test
+    TestResult testModuleRoPE();               // RoPE module test
+    TestResult testDtypeAssertion();           // Test dtype assertions when loading parameters
+    TestResult testTinyLlamaConstruction();    // Comprehensive: construction + weight loading + validation
 };
 
 } // namespace infinicore::test
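
MockLinearModule now forwards its tensor-parallel configuration to the weight, while the bias follows the split only when tp_dim == 0 (its single dimension is the output dimension); for tp_dim == 1 every rank keeps the full bias. A minimal construction sketch mirroring the test above, for one shard of a 2-way, dim-0 split:

    // Usage sketch based on the test code in this commit: rank 1 of a 2-way group,
    // split along dim 0, so the local weight is {16, 64} and the local bias {16}.
    auto device = infinicore::context::getDevice();
    MockLinearModule shard(/*input_size=*/64, /*output_size=*/32, device,
                           /*tp_dim=*/0, /*tp_rank=*/1, /*tp_size=*/2);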

src/infinicore/context/context_impl.cc

Lines changed: 7 additions & 3 deletions
@@ -33,11 +33,15 @@ Runtime *ContextImpl::getCpuRuntime() {
     return runtime_table_[int(Device::Type::CPU)][0].get();
 }
 
-void ContextImpl::setDevice(Device device) {
+void ContextImpl::setDevice(Device device, bool force_cpu) {
     if (device == getCurrentRuntime()->device()) {
         // Do nothing if the device is already set.
         return;
     }
+    if (device == Device(Device::Type::CPU, 0) && !force_cpu) {
+        // if not forced, no need to switch to CPU device runtime
+        return;
+    }
 
     if (runtime_table_[int(device.getType())][device.getIndex()] == nullptr) {
         // Lazy initialization of runtime if never set before.
@@ -83,8 +87,8 @@ ContextImpl::ContextImpl() {
 
 namespace context {
 
-void setDevice(Device device) {
-    ContextImpl::singleton().setDevice(device);
+void setDevice(Device device, bool force_cpu) {
+    ContextImpl::singleton().setDevice(device, force_cpu);
 }
 
 Device getDevice() {
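
With this guard, requesting the CPU device is ignored unless the caller forces it, so code that briefly touches host tensors does not silently drop the active accelerator runtime. A usage sketch of the resulting behaviour, inferred from the guard added above (assumes some non-CPU runtime is currently active):

    // Behaviour sketch for the new force_cpu flag (inferred from the guard above).
    infinicore::context::setDevice(infinicore::Device::cpu());                     // no-op: CPU not forced
    infinicore::context::setDevice(infinicore::Device::cpu(), /*force_cpu=*/true); // switches to the CPU runtime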
