@@ -223,13 +223,20 @@ auto aten_registrations TORCHTRT_UNUSED =
223223 {c10::Symbol::fromQualString (" aten::slice" ),
224224 [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
225225 c10::List<c10::IValue> list = args.at (n->input (0 )).IValue ()->to <c10::List<c10::IValue>>();
226-
227226 int64_t start = 0 ;
227+ int64_t end = 9223372036854775807 ;
228228 auto startIVal = args.at (n->input (1 )).IValue ();
229+ auto endIVal = args.at (n->input (2 )).IValue ();
230+
229231 if (!startIVal->isNone ()) {
230232 start = args.at (n->input (1 )).unwrapToInt ();
231233 }
232- int64_t end = args.at (n->input (2 )).unwrapToInt ();
234+ if (!endIVal->isNone ()) {
235+ end = args.at (n->input (2 )).unwrapToInt ();
236+ }
237+ if (start > end) {
238+ LOG_DEBUG (" The end should be greater than start" );
239+ }
233240 int64_t step = args.at (n->input (3 )).unwrapToInt ();
234241
235242 const int64_t list_size = list.size ();
@@ -253,8 +260,9 @@ auto aten_registrations TORCHTRT_UNUSED =
253260
254261 return sliced_list;
255262 },
256- EvalOptions ().validSchemas (
257- {" aten::slice.t(t[] l, int start, int end=9223372036854775807, int step=1) -> (t[])" })})
263+ EvalOptions ().validSchemas ({" aten::slice.t(t[] l, int? start=None, int? end=None, int step=1) -> (t[])" })})
264+ // EvalOptions().validSchemas(
265+ // {"aten::slice.t(t[] l, int start, int end=9223372036854775807, int step=1) -> (t[])"})})
258266 .evaluator(
259267 {c10::Symbol::fromQualString (" aten::len" ),
260268 [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
@@ -896,8 +904,14 @@ auto aten_registrations TORCHTRT_UNUSED =
896904 auto step = args.at (n->input (2 )).unwrapToInt ();
897905 return start + idx * step;
898906 },
899- EvalOptions ().validSchemas ({" aten::__derive_index(int idx, int start, int step) -> int" })});
900-
907+ EvalOptions ().validSchemas ({" aten::__derive_index(int idx, int start, int step) -> int" })})
908+ .evaluator(
909+ {c10::Symbol::fromQualString (" aten::list" ),
910+ [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
911+ c10::List<c10::IValue> list = args.at (n->input (0 )).IValue ()->to <c10::List<c10::IValue>>();
912+ return list.copy ();
913+ },
914+ EvalOptions ().validSchemas ({" aten::list.t(t[] l) -> (t[])" })});
901915} // namespace
902916} // namespace evaluators
903917} // namespace conversion
0 commit comments