From 1e6e9d30ea1498dbcecf211b1cbcaebc40205e18 Mon Sep 17 00:00:00 2001
From: StijnWoestenborghs
Date: Tue, 14 May 2024 14:34:53 +0200
Subject: [PATCH 1/6] itertools.product & unvectorized fw index

---
 basalt/autograd/ops/mlops.mojo | 78 +++++++++++++++++++++++++++++++++-
 basalt/autograd/ops/ops.mojo   |  9 +++-
 basalt/utils/itertools.mojo    | 47 ++++++++++++++++++++
 tests/mojo/test_mlops.mojo     | 76 +++++++++++++++++++++++----------
 4 files changed, 185 insertions(+), 25 deletions(-)
 create mode 100644 basalt/utils/itertools.mojo

diff --git a/basalt/autograd/ops/mlops.mojo b/basalt/autograd/ops/mlops.mojo
index 0869919..0f9bb1f 100644
--- a/basalt/autograd/ops/mlops.mojo
+++ b/basalt/autograd/ops/mlops.mojo
@@ -4,6 +4,7 @@ from math.limit import min_finite, max_finite
 
 from basalt import Tensor, TensorShape
 from basalt.utils.tensorutils import elwise_transform
+from basalt.utils.itertools import product
 from basalt.autograd.attributes import Attribute, AttributeVector
 
 
@@ -491,4 +492,79 @@ struct SLICE:
 
         Self.slice_kernel[ug_shape, t1_shape, steps, starts, ends, True](res_grad, ug)
 
-        return res_grad ^
\ No newline at end of file
+        return res_grad ^
+
+
+struct INDEX:
+    @staticmethod
+    fn adjust_boundary(slice: Int, dim_size: Int) -> Int:
+        # Adjust negative indices & ensure they are within bounds.
+        var s = slice if slice >= 0 else dim_size + slice
+        return max(min(s, dim_size), 0)
+
+    @staticmethod
+    fn to_indeces(shape: TensorShape, attrs: AttributeVector) -> List[List[Int]]:
+        var SLICE_LITERALS = List[StringLiteral]("dim_0s", "dim_1s", "dim_2s", "dim_3s", "dim_4s", "dim_5s", "dim_6s", "dim_7s")
+        var INDEX_LITERALS = List[StringLiteral]("dim_0i", "dim_1i", "dim_2i", "dim_3i", "dim_4i", "dim_5i", "dim_6i", "dim_7i")
+
+        var indeces = List[List[Int]]()
+        for dim in range(shape.rank()):
+            var temp = List[Int]()
+
+            # Option 1: Slice
+            if attrs[SLICE_LITERALS[dim]]:
+                var slice = attrs[SLICE_LITERALS[dim]].value().to_shape()
+                var step = slice[2] if slice.rank() == 3 else 1
+                for i in range(
+                    start=Self.adjust_boundary(slice[0], shape[dim]),
+                    end=Self.adjust_boundary(slice[1], shape[dim]),
+                    step=step
+                ):
+                    temp.append(i)
+
+            # Option 2: Indices
+            elif attrs[INDEX_LITERALS[dim]]:
+                var indeces = attrs[INDEX_LITERALS[dim]].value().to_shape()
+                for i in range(indeces.rank()):
+                    temp.append(indeces[i])
+
+            # All indices
+            else:
+                for i in range(shape[dim]):
+                    temp.append(i)
+
+            indeces.append(temp)
+
+        return indeces ^
+
+    @staticmethod
+    fn result_shape(shape: TensorShape, attrs: AttributeVector) -> TensorShape:
+        var indeces = Self.to_indeces(shape, attrs)
+        var new_shape = List[Int]()
+        for i in range(shape.rank()):
+            new_shape.append(len(indeces[i]))
+        return TensorShape(new_shape)
+
+    @staticmethod
+    fn forward[
+        t1_shape: TensorShape,
+        attributes: AttributeVector,
+    ](inout res: Tensor[dtype], t1: Tensor[dtype]):
+        alias indeces = Self.to_indeces(t1_shape, attributes)
+        alias strides = t1_shape.strides()
+
+        var j = 0
+        for comb in product(indeces):
+            var flat_index = 0
+            for dim in range(t1_shape.rank()):
+                flat_index += comb[dim] * strides[dim]
+            res[j] = t1[flat_index]
+            j += 1
+
+    @staticmethod
+    fn backward[
+        ug_shape: TensorShape,
+        t1_shape: TensorShape,
+        attributes: AttributeVector = AttributeVector(),
+    ](ug: Tensor[dtype], t1: Tensor[dtype]) -> Tensor[dtype]:
+        return Tensor[dtype]()
\ No newline at end of file
diff --git a/basalt/autograd/ops/ops.mojo b/basalt/autograd/ops/ops.mojo
index 7198270..c737821 100644
--- a/basalt/autograd/ops/ops.mojo
+++ b/basalt/autograd/ops/ops.mojo
@@ -15,7 +15,7 @@ from .basics import (
     TRANSPOSE,
     FMA,
 )
-from .mlops import SIGMOID, RELU, TANH, CLIP, SQUEEZE, UNSQUEEZE, SLICE
+from .mlops import SIGMOID, RELU, TANH, CLIP, SQUEEZE, UNSQUEEZE, SLICE, INDEX
 from .dynamics import CONCAT, SPLIT
 from .conv import CONV2D
 from .pool import MAXPOOL2D
@@ -61,6 +61,7 @@ struct OP(Stringable):
     alias CONCAT = OP(23, "CONCAT", dynamic=True)
     alias SPLIT = OP(24, "SPLIT", dynamic=True)
     alias SLICE = OP(25, "SLICE")
+    alias INDEX = OP(26, "INDEX")
 
     var id: UInt8
     var name: Bytes[16]
@@ -135,6 +136,8 @@ fn static_result_shape(
         return UNSQUEEZE.result_shape(t1_shape, attributes)
     elif op == OP.SLICE:
         return SLICE.result_shape(t1_shape, attributes)
+    elif op == OP.INDEX:
+        return INDEX.result_shape(t1_shape, attributes)
     else:
         print("[ERROR] Operator not found.")
         return TensorShape(-1)
@@ -249,6 +252,8 @@ fn forward_op[
         UNSQUEEZE.forward[t1_shape, attributes](res, t1)
     elif op == OP.SLICE:
         SLICE.forward[t1_shape, attributes](res, t1)
+    elif op == OP.INDEX:
+        INDEX.forward[t1_shape, attributes](res, t1)
     else:
         print("[ERROR] Operator not found.")
@@ -361,6 +366,8 @@ fn backward_op[
         res_grad = UNSQUEEZE.backward[ug_shape, t1_shape](ug, t1)
     elif op == OP.SLICE:
         res_grad = SLICE.backward[ug_shape, t1_shape, attributes](ug, t1)
+    elif op == OP.INDEX:
+        res_grad = INDEX.backward[ug_shape, t1_shape, attributes](ug, t1)
     else:
         print("[ERROR] Operator not found.")
         res_grad = Tensor[dtype](-1)
diff --git a/basalt/utils/itertools.mojo b/basalt/utils/itertools.mojo
new file mode 100644
index 0000000..aceda31
--- /dev/null
+++ b/basalt/utils/itertools.mojo
@@ -0,0 +1,47 @@
+
+@value
+struct _ProductIterator(Sized):
+    var lists: List[List[Int]]
+    var indeces: List[Int]
+    var _iters: Int
+
+    @always_inline("nodebug")
+    fn __init__(inout self, lists: List[List[Int]]):
+        self.lists = lists
+        self.indeces = List[Int]()
+        for i in range(len(lists)):
+            self.indeces.append(0)
+
+        self._iters = 1
+        for lst in self.lists:
+            self._iters *= len(lst[])
+
+    @always_inline("nodebug")
+    fn __len__(self) -> Int:
+        return self._iters
+
+    @always_inline("nodebug")
+    fn __iter__(self) -> Self:
+        return self
+
+    @always_inline("nodebug")
+    fn __next__(inout self) -> List[Int]:
+        var res = List[Int]()
+        for i in range(len(self.lists)):
+            res.append(self.lists[i][self.indeces[i]])
+        self._increment_indeces()
+        self._iters -= 1
+        return res ^
+
+    @always_inline("nodebug")
+    fn _increment_indeces(inout self):
+        for i in reversed(range(len(self.indeces))):
+            self.indeces[i] += 1
+            if self.indeces[i] < len(self.lists[i]):
+                break
+            self.indeces[i] = 0
+
+
+@always_inline("nodebug")
+fn product(lists: List[List[Int]]) -> _ProductIterator:
+    return _ProductIterator(lists)
\ No newline at end of file
diff --git a/tests/mojo/test_mlops.mojo b/tests/mojo/test_mlops.mojo
index 2ba723e..4d87bb1 100644
--- a/tests/mojo/test_mlops.mojo
+++ b/tests/mojo/test_mlops.mojo
@@ -620,33 +620,63 @@ fn test_backward_SLICE_multiple_axes() raises:
     ](t1, ug, expected_ug)
 
 
+from basalt.autograd.ops.mlops import INDEX
+
+fn test_INDEX() raises:
+    alias t1_shape = TensorShape(2, 3, 5)
+    var t = Tensor[dtype](t1_shape)
+    for i in range(t.num_elements()):
+        t[i] = i
+
+    # t[:, [0, 0], 0:5:2]
+    # TODO: a list attribute is needed, as this only supports specifying indices up to MAX_RANK
+    alias attr_1 = Attribute("dim_1i", TensorShape(0, 0))
+    alias attr_2 = Attribute("dim_2s", TensorShape(0, 5, 2))
+
+    var expected = Tensor[dtype](2, 2, 3)
+    for i in range(2):
+        for j in range(2):
+            for k in range(3):
+                expected[i*2*3 + j*3 + k] = i * 3 * 5 + k * 2
+
+    test_unary_op[
+        OP.INDEX, t1_shape, AttributeVector(
+            attr_1,
+            attr_2,
+        )
+    ](t, expected)
+
+    print(expected)
+
+
 fn main():
     try:
-        test_SIGMOID()
-        test_RELU()
-        test_TANH()
-        test_CLIP()
-        test_SQUEEZE()
-        test_UNSQUEEZE()
-        test_SLICE()
-        test_SLICE_step()
-        test_SLICE_neg()
-        test_SLICE_multiple_axes()
+        # test_SIGMOID()
+        # test_RELU()
+        # test_TANH()
+        # test_CLIP()
+        # test_SQUEEZE()
+        # test_UNSQUEEZE()
+        # test_SLICE()
+        # test_SLICE_step()
+        # test_SLICE_neg()
+        # test_SLICE_multiple_axes()
+        test_INDEX()
     except e:
         print("[ERROR] Error in forward mlops")
         print(e)
         return
 
-    try:
-        test_backward_SIGMOID()
-        test_backward_RELU()
-        test_backward_TANH()
-        test_backward_CLIP()
-        test_backward_SQUEEZE()
-        test_backward_UNSQUEEZE()
-        test_backward_SLICE()
-        test_backward_SLICE_multiple_axes()
-    except e:
-        print("[ERROR] Error in backward mlops")
-        print(e)
-        return
+    # try:
+    #     test_backward_SIGMOID()
+    #     test_backward_RELU()
+    #     test_backward_TANH()
+    #     test_backward_CLIP()
+    #     test_backward_SQUEEZE()
+    #     test_backward_UNSQUEEZE()
+    #     test_backward_SLICE()
+    #     test_backward_SLICE_multiple_axes()
+    # except e:
+    #     print("[ERROR] Error in backward mlops")
+    #     print(e)
+    #     return

From 60e510844596ede4fe51a22ed1c54fe27fc50292 Mon Sep 17 00:00:00 2001
From: StijnWoestenborghs
Date: Tue, 14 May 2024 14:56:27 +0200
Subject: [PATCH 2/6] unoptimized index bw

---
 basalt/autograd/ops/mlops.mojo | 15 ++++++++-
 tests/mojo/test_mlops.mojo     | 57 ++++++++++++++++++++++++++--------
 2 files changed, 58 insertions(+), 14 deletions(-)

diff --git a/basalt/autograd/ops/mlops.mojo b/basalt/autograd/ops/mlops.mojo
index 0f9bb1f..5aa2d8b 100644
--- a/basalt/autograd/ops/mlops.mojo
+++ b/basalt/autograd/ops/mlops.mojo
@@ -567,4 +567,17 @@ struct INDEX:
         t1_shape: TensorShape,
         attributes: AttributeVector = AttributeVector(),
     ](ug: Tensor[dtype], t1: Tensor[dtype]) -> Tensor[dtype]:
-        return Tensor[dtype]()
\ No newline at end of file
+        alias indeces = Self.to_indeces(t1_shape, attributes)
+        alias strides = t1_shape.strides()
+
+        var res_grad = Tensor[dtype](t1_shape)
+
+        var j = 0
+        for comb in product(indeces):
+            var flat_index = 0
+            for dim in range(t1_shape.rank()):
+                flat_index += comb[dim] * strides[dim]
+            res_grad[flat_index] += ug[j]
+            j += 1
+
+        return res_grad^
\ No newline at end of file
diff --git a/tests/mojo/test_mlops.mojo b/tests/mojo/test_mlops.mojo
index 4d87bb1..964e134 100644
--- a/tests/mojo/test_mlops.mojo
+++ b/tests/mojo/test_mlops.mojo
@@ -649,6 +649,36 @@ fn test_INDEX() raises:
     print(expected)
 
 
+fn test_INDEX_backward() raises:
+    alias t1_shape = TensorShape(2, 3, 5)
+    var t = Tensor[dtype](t1_shape)
+    for i in range(t.num_elements()):
+        t[i] = i
+
+    alias attr_1 = Attribute("dim_1i", TensorShape(0, 0))
+    alias attr_2 = Attribute("dim_2s", TensorShape(0, 5, 2))
+
+    alias ug_shape = TensorShape(2, 2, 3)
+    var ug = Tensor[dtype](ug_shape)
+    fill(ug, 1.0)
+
+    var expected = Tensor[dtype](t1_shape)
+    for i in range(2):
+        for j in range(2):
+            for k in range(3):
+                # NOTE: `+=` because selected indices [0, 0] can repeat
+                expected[i * 3 * 5 + k * 2] += 1.0
+
+    test_unary_op_backward[
+        OP.INDEX, t1_shape, ug_shape, AttributeVector(
+            attr_1,
+            attr_2,
+        )
+    ](t, ug, expected)
+
+    print(expected)
+
+
 fn main():
     try:
         # test_SIGMOID()
         # test_RELU()
         # test_TANH()
         # test_CLIP()
         # test_SQUEEZE()
         # test_UNSQUEEZE()
         # test_SLICE()
         # test_SLICE_step()
         # test_SLICE_neg()
         # test_SLICE_multiple_axes()
         test_INDEX()
     except e:
         print("[ERROR] Error in forward mlops")
         print(e)
         return
 
-    # try:
-    #     test_backward_SIGMOID()
-    #     test_backward_RELU()
-    #     test_backward_TANH()
-    #     test_backward_CLIP()
-    #     test_backward_SQUEEZE()
-    #     test_backward_UNSQUEEZE()
-    #     test_backward_SLICE()
-    #     test_backward_SLICE_multiple_axes()
-    # except e:
-    #     print("[ERROR] Error in backward mlops")
-    #     print(e)
-    #     return
+    try:
+        # test_backward_SIGMOID()
+        # test_backward_RELU()
+        # test_backward_TANH()
+        # test_backward_CLIP()
+        # test_backward_SQUEEZE()
+        # test_backward_UNSQUEEZE()
+        # test_backward_SLICE()
+        # test_backward_SLICE_multiple_axes()
+        test_INDEX_backward()
+    except e:
+        print("[ERROR] Error in backward mlops")
+        print(e)
+        return

From 8d90c09d852356ac5f30c8328d5549d785e9c827 Mon Sep 17 00:00:00 2001
From: StijnWoestenborghs
Date: Tue, 14 May 2024 18:10:38 +0200
Subject: [PATCH 3/6] getindex to product & vectorized fw

---
 basalt/autograd/ops/mlops.mojo | 55 +++++++++++++++++++++++++++++-----
 basalt/utils/itertools.mojo    | 34 +++++++++++----------
 2 files changed, 65 insertions(+), 24 deletions(-)

diff --git a/basalt/autograd/ops/mlops.mojo b/basalt/autograd/ops/mlops.mojo
index 5aa2d8b..fd871fd 100644
--- a/basalt/autograd/ops/mlops.mojo
+++ b/basalt/autograd/ops/mlops.mojo
@@ -545,6 +545,26 @@ struct INDEX:
             new_shape.append(len(indeces[i]))
         return TensorShape(new_shape)
 
+    @staticmethod
+    fn map_indeces[
+        nelts: Int,
+        strides: TensorShape,
+        indeces: List[List[Int]],
+    ](idx: Int) -> SIMD[DType.int64, nelts]:
+        alias indeces_product = product(indeces)
+
+        var temp = SIMD[DType.int64, nelts]()
+        for i in range(idx, idx + nelts):
+            var comb = indeces_product[i]
+            var flat_index = 0
+
+            for dim in range(len(comb)):
+                flat_index += comb[dim] * strides[dim]
+
+            temp[i % nelts] = flat_index
+
+        return temp
+
     @staticmethod
     fn forward[
         t1_shape: TensorShape,
@@ -552,14 +572,17 @@ struct INDEX:
     ](inout res: Tensor[dtype], t1: Tensor[dtype]):
         alias indeces = Self.to_indeces(t1_shape, attributes)
         alias strides = t1_shape.strides()
+        alias total_length = len(product(indeces))
+
+        @parameter
+        fn vec_index[nelts: Int](i: Int):
+
+            res.store[nelts](i,
+                t1.data().gather(Self.map_indeces[nelts, strides, indeces](i))
+            )
+
+        vectorize[vec_index, nelts](total_length)
 
-        var j = 0
-        for comb in product(indeces):
-            var flat_index = 0
-            for dim in range(t1_shape.rank()):
-                flat_index += comb[dim] * strides[dim]
-            res[j] = t1[flat_index]
-            j += 1
+
     @staticmethod
     fn backward[
@@ -569,9 +592,25 @@ struct INDEX:
     ](ug: Tensor[dtype], t1: Tensor[dtype]) -> Tensor[dtype]:
         alias indeces = Self.to_indeces(t1_shape, attributes)
         alias strides = t1_shape.strides()
+        # alias total_length = len(product(indeces))
 
         var res_grad = Tensor[dtype](t1_shape)
 
+        # @parameter
+        # fn vec_index[nelts: Int](i: Int):
+
+        #     var offset = Self.map_indeces[nelts, strides, indeces](i)
+        #     res_grad.data().scatter(
+        #         offset,
+        #         res_grad.data().gather(offset) + ug.load[nelts](i),
+        #     )
+
+        # vectorize[vec_index, nelts](total_length)
+
+        # BUG: Edge case in vectorization:
+        # When the offset = [0, 2, 4, 0] and ug = [1, 1, 1, 1]
+        # It doesn't scatter to index 0 twice as it should be: res_grad[0] += 1 + 1
+
         var j = 0
         for comb in product(indeces):
             var flat_index = 0
@@ -579,5 +618,5 @@ struct INDEX:
                 flat_index += comb[dim] * strides[dim]
             res_grad[flat_index] += ug[j]
             j += 1
-
+
         return res_grad^
\ No newline at end of file
diff --git a/basalt/utils/itertools.mojo b/basalt/utils/itertools.mojo
index aceda31..fd7a6ce 100644
--- a/basalt/utils/itertools.mojo
+++ b/basalt/utils/itertools.mojo
@@ -2,16 +2,14 @@
 @value
 struct _ProductIterator(Sized):
     var lists: List[List[Int]]
-    var indeces: List[Int]
+    var _current: Int
     var _iters: Int
 
     @always_inline("nodebug")
     fn __init__(inout self, lists: List[List[Int]]):
         self.lists = lists
-        self.indeces = List[Int]()
-        for i in range(len(lists)):
-            self.indeces.append(0)
-
+        self._current = 0
+
         self._iters = 1
         for lst in self.lists:
             self._iters *= len(lst[])
@@ -26,20 +24,24 @@ struct _ProductIterator(Sized):
     @always_inline("nodebug")
     fn __next__(inout self) -> List[Int]:
-        var res = List[Int]()
-        for i in range(len(self.lists)):
-            res.append(self.lists[i][self.indeces[i]])
-        self._increment_indeces()
+        self._current += 1
         self._iters -= 1
-        return res ^
+        return self._get_combination(self._current - 1)
+
+    @always_inline("nodebug")
+    fn _get_combination(self, current: Int) -> List[Int]:
+        var combination = List[Int]()
+        var count = current
+        for i in reversed(range(len(self.lists))):
+            var index = count % len(self.lists[i])
+            combination.append(self.lists[i][index])
+            count //= len(self.lists[i])
+        combination._reverse()
+        return combination ^
 
     @always_inline("nodebug")
-    fn _increment_indeces(inout self):
-        for i in reversed(range(len(self.indeces))):
-            self.indeces[i] += 1
-            if self.indeces[i] < len(self.lists[i]):
-                break
-            self.indeces[i] = 0
+    fn __getitem__(self, index: Int) -> List[Int]:
+        return self._get_combination(index)
 
 
 @always_inline("nodebug")

From 113b5aeabbf5479638addbf3b9afe24c189c73aa Mon Sep 17 00:00:00 2001
From: StijnWoestenborghs
Date: Tue, 14 May 2024 19:19:02 +0200
Subject: [PATCH 4/6] something in between

---
 basalt/autograd/ops/mlops.mojo | 39 ++++++++++++++++------------------
 1 file changed, 18 insertions(+), 21 deletions(-)

diff --git a/basalt/autograd/ops/mlops.mojo b/basalt/autograd/ops/mlops.mojo
index fd871fd..6e38aaa 100644
--- a/basalt/autograd/ops/mlops.mojo
+++ b/basalt/autograd/ops/mlops.mojo
@@ -592,31 +592,28 @@ struct INDEX:
     ](ug: Tensor[dtype], t1: Tensor[dtype]) -> Tensor[dtype]:
         alias indeces = Self.to_indeces(t1_shape, attributes)
         alias strides = t1_shape.strides()
-        # alias total_length = len(product(indeces))
+        alias total_length = len(product(indeces))
 
         var res_grad = Tensor[dtype](t1_shape)
 
-        # @parameter
-        # fn vec_index[nelts: Int](i: Int):
+        @parameter
+        fn vec_index[nelts: Int](i: Int):
 
-        #     var offset = Self.map_indeces[nelts, strides, indeces](i)
-        #     res_grad.data().scatter(
-        #         offset,
-        #         res_grad.data().gather(offset) + ug.load[nelts](i),
-        #     )
-
-        # vectorize[vec_index, nelts](total_length)
-
-        # BUG: Edge case in vectorization:
-        # When the offset = [0, 2, 4, 0] and ug = [1, 1, 1, 1]
-        # It doesn't scatter to index 0 twice as it should be: res_grad[0] += 1 + 1
+            var offset = Self.map_indeces[nelts, strides, indeces](i)
+
+            # res_grad.data().scatter(
+            #     offset,
+            #     res_grad.data().gather(offset) + ug.load[nelts](i),
+            # )
+            # BUG: Edge case in vectorization:
+            # When the offset = [0, 2, 4, 0] and ug = [1, 1, 1, 1]
+            # It doesn't scatter to index 0 twice as it should be: res_grad[0] += 1 + 1
+
+            # Workaround
+            var u = ug.load[nelts](i)
+            for j in range(nelts):
+                res_grad[int(offset[j])] += u[j]
 
-        var j = 0
-        for comb in product(indeces):
-            var flat_index = 0
-            for dim in range(t1_shape.rank()):
-                flat_index += comb[dim] * strides[dim]
-            res_grad[flat_index] += ug[j]
-            j += 1
+        vectorize[vec_index, nelts](total_length)
 
         return res_grad^
\ No newline at end of file

From 854c13172fd40ac52b837c125dc09673962ce74b Mon Sep 17 00:00:00 2001
From: benny-nottonson
Date: Fri, 17 May 2024 16:08:36 -0700
Subject: [PATCH 5/6] added list.reserve() where possible

Signed-off-by: benny-nottonson
---
 basalt/autograd/ops/mlops.mojo | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)
diff --git a/basalt/autograd/ops/mlops.mojo b/basalt/autograd/ops/mlops.mojo
index 6e38aaa..b4ec423 100644
--- a/basalt/autograd/ops/mlops.mojo
+++ b/basalt/autograd/ops/mlops.mojo
@@ -507,8 +507,11 @@ struct INDEX:
         var SLICE_LITERALS = List[StringLiteral]("dim_0s", "dim_1s", "dim_2s", "dim_3s", "dim_4s", "dim_5s", "dim_6s", "dim_7s")
         var INDEX_LITERALS = List[StringLiteral]("dim_0i", "dim_1i", "dim_2i", "dim_3i", "dim_4i", "dim_5i", "dim_6i", "dim_7i")
 
+        var rank = shape.rank()
         var indeces = List[List[Int]]()
-        for dim in range(shape.rank()):
+        indeces.reserve(rank)
+
+        for dim in range(rank):
             var temp = List[Int]()
 
             # Option 1: Slice
@@ -540,8 +543,10 @@ struct INDEX:
     @staticmethod
     fn result_shape(shape: TensorShape, attrs: AttributeVector) -> TensorShape:
         var indeces = Self.to_indeces(shape, attrs)
+        var rank = shape.rank()
         var new_shape = List[Int]()
-        for i in range(shape.rank()):
+        new_shape.reserve(rank)
+        for i in range(rank):
             new_shape.append(len(indeces[i]))
         return TensorShape(new_shape)
 

From 0efb4e46575a4c463be0c09af8770e783539f80a Mon Sep 17 00:00:00 2001
From: StijnWoestenborghs
Date: Fri, 31 May 2024 11:00:12 +0200
Subject: [PATCH 6/6] index bw notes

---
 basalt/autograd/ops/mlops.mojo | 10 +++++---
 tests/mojo/test_mlops.mojo     | 42 +++++++++++++++-------------------
 2 files changed, 25 insertions(+), 27 deletions(-)

diff --git a/basalt/autograd/ops/mlops.mojo b/basalt/autograd/ops/mlops.mojo
index b4ec423..29f5e39 100644
--- a/basalt/autograd/ops/mlops.mojo
+++ b/basalt/autograd/ops/mlops.mojo
@@ -610,10 +610,14 @@ struct INDEX:
             #     offset,
             #     res_grad.data().gather(offset) + ug.load[nelts](i),
             # )
-            # BUG: Edge case in vectorization:
+            # NOTE: Scatter (reduce SUM) required
             # When the offset = [0, 2, 4, 0] and ug = [1, 1, 1, 1]
-            # It doesn't scatter to index 0 twice as it should be: res_grad[0] += 1 + 1
+            # The standard scatter will overwrite the values with overlapping indices.
+            # It doesn't accumulate index 0 twice as it should: res_grad[0] += 1 + 1
+            # cf. https://github.com/ml-explore/mlx/blob/main/mlx/backend/common/indexing.cpp#L256-L258
+            # cf. https://github.com/modularml/mojo/blob/main/stdlib/src/sys/intrinsics.mojo#L903
+
             # Workaround
             var u = ug.load[nelts](i)
             for j in range(nelts):
                 res_grad[int(offset[j])] += u[j]
diff --git a/tests/mojo/test_mlops.mojo b/tests/mojo/test_mlops.mojo
index 964e134..4085d48 100644
--- a/tests/mojo/test_mlops.mojo
+++ b/tests/mojo/test_mlops.mojo
@@ -620,8 +620,6 @@ fn test_backward_SLICE_multiple_axes() raises:
     ](t1, ug, expected_ug)
 
 
-from basalt.autograd.ops.mlops import INDEX
-
 fn test_INDEX() raises:
     alias t1_shape = TensorShape(2, 3, 5)
     var t = Tensor[dtype](t1_shape)
     for i in range(t.num_elements()):
         t[i] = i
@@ -646,8 +644,6 @@ fn test_INDEX() raises:
         )
     ](t, expected)
 
-    print(expected)
-
 
 fn test_INDEX_backward() raises:
     alias t1_shape = TensorShape(2, 3, 5)
     var t = Tensor[dtype](t1_shape)
     for i in range(t.num_elements()):
         t[i] = i
@@ -676,21 +672,19 @@ fn test_INDEX_backward() raises:
         )
     ](t, ug, expected)
 
-    print(expected)
-
 
 fn main():
     try:
-        # test_SIGMOID()
-        # test_RELU()
-        # test_TANH()
-        # test_CLIP()
-        # test_SQUEEZE()
-        # test_UNSQUEEZE()
-        # test_SLICE()
-        # test_SLICE_step()
-        # test_SLICE_neg()
-        # test_SLICE_multiple_axes()
+        test_SIGMOID()
+        test_RELU()
+        test_TANH()
+        test_CLIP()
+        test_SQUEEZE()
+        test_UNSQUEEZE()
+        test_SLICE()
+        test_SLICE_step()
+        test_SLICE_neg()
+        test_SLICE_multiple_axes()
         test_INDEX()
     except e:
         print("[ERROR] Error in forward mlops")
         print(e)
         return
 
     try:
-        # test_backward_SIGMOID()
-        # test_backward_RELU()
-        # test_backward_TANH()
-        # test_backward_CLIP()
-        # test_backward_SQUEEZE()
-        # test_backward_UNSQUEEZE()
-        # test_backward_SLICE()
-        # test_backward_SLICE_multiple_axes()
+        test_backward_SIGMOID()
+        test_backward_RELU()
+        test_backward_TANH()
+        test_backward_CLIP()
+        test_backward_SQUEEZE()
+        test_backward_UNSQUEEZE()
+        test_backward_SLICE()
+        test_backward_SLICE_multiple_axes()
         test_INDEX_backward()
     except e:
         print("[ERROR] Error in backward mlops")
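
Usage sketch (illustrative, not part of the patch series): the `product` utility introduced in basalt/utils/itertools.mojo mirrors Python's `itertools.product` over a `List[List[Int]]`. A minimal example, assuming the same Mojo toolchain these patches target; the `dims` values and the `main` entry point are invented for illustration only:

# Illustrative sketch: iterates the Cartesian product [0, 1] x [0, 2, 4],
# the same kind of per-dimension index lists that INDEX.forward feeds to product().
from basalt.utils.itertools import product

fn main():
    var d0 = List[Int]()
    d0.append(0)
    d0.append(1)

    var d1 = List[Int]()
    d1.append(0)
    d1.append(2)
    d1.append(4)

    var dims = List[List[Int]]()
    dims.append(d0)
    dims.append(d1)

    # len(product(dims)) == 6; yields [0, 0], [0, 2], [0, 4], [1, 0], [1, 2], [1, 4]
    for comb in product(dims):
        for c in comb:
            print(c[])  # list iteration yields references, hence the dereference

Note that after PATCH 3 the iterator also supports random access (`__getitem__` via `_get_combination`), which is what `map_indeces` relies on to build a SIMD vector of flat offsets starting at an arbitrary position.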
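For reference, the linearization that both `forward` and `backward` perform: with `t1_shape = (2, 3, 5)` the row-major strides are `(15, 5, 1)`, so a combination `[i, j, k]` from `product` maps to the flat offset `i*15 + j*5 + k*1`; for example `[1, 0, 4]` gives `1*15 + 0*5 + 4*1 = 19`. A hypothetical standalone helper mirroring that accumulation (the patches inline this loop rather than calling a helper):

# Hypothetical helper, for illustration only: mirrors the flat_index
# accumulation inlined in INDEX.forward and INDEX.backward above.
fn flat_index(comb: List[Int], strides: List[Int]) -> Int:
    var idx = 0
    for dim in range(len(comb)):
        idx += comb[dim] * strides[dim]  # row-major: offset = sum(index * stride)
    return idx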