@@ -446,53 +446,47 @@ def create_constant(
446446 else :
447447 shape = list (torch_value .shape )
448448
449- if torch_value is not None :
450-
451- if torch_value .dtype == torch .uint8 :
452- if is_tensorrt_version_supported ("10.8.0" ):
453- if (
454- target_quantized_type is None
455- or target_quantized_type != trt .DataType .FP4
456- ):
457- # Iconstant layer does not support Uint8, it only support that FP4 data packed in uint8
458- raise ValueError (
459- "Currently supported target_quantized_type for uint8 is FP4, got {target_quantized_type=}"
460- )
461- shape [- 1 ] = shape [- 1 ] * 2
462- weights = to_trt_weights (
463- ctx ,
464- torch_value ,
465- name ,
466- "CONSTANT" ,
467- "CONSTANT" ,
468- dtype = trt .DataType .FP4 ,
469- count = torch_value .numel () * 2 ,
470- )
471- constant = ctx .net .add_constant (
472- shape ,
473- weights ,
474- )
475- constant .name = name
476- return constant .get_output (0 )
477- else :
449+ if torch_value .dtype == torch .uint8 :
450+ if is_tensorrt_version_supported ("10.8.0" ):
451+ if (
452+ target_quantized_type is None
453+ or target_quantized_type != trt .DataType .FP4
454+ ):
455+ # Iconstant layer does not support Uint8, it only support that FP4 data packed in uint8
478456 raise ValueError (
479- "Currently FP4 is only supported in TensorRT 10.8.0 and above "
457+ "Currently supported target_quantized_type for uint8 is FP4, got {target_quantized_type=} "
480458 )
481- # Record the weight in ctx for refit and cpu memory reference
459+ shape [- 1 ] = shape [- 1 ] * 2
460+ weights = to_trt_weights (
461+ ctx ,
462+ torch_value ,
463+ name ,
464+ "CONSTANT" ,
465+ "CONSTANT" ,
466+ dtype = trt .DataType .FP4 ,
467+ count = torch_value .numel () * 2 ,
468+ )
469+ constant = ctx .net .add_constant (
470+ shape ,
471+ weights ,
472+ )
473+ constant .name = name
474+ return constant .get_output (0 )
475+ else :
476+ raise ValueError (
477+ "Currently FP4 is only supported in TensorRT 10.8.0 and above"
478+ )
479+ # Record the weight in ctx for refit and cpu memory reference
482480
483- # Convert the torch.Tensor to a trt.Weights object
484- trt_weights = to_trt_weights (ctx , torch_value , name , "CONSTANT" , "CONSTANT" )
485- constant = ctx .net .add_constant (
486- shape ,
487- trt_weights ,
488- )
489- constant .name = name
481+ # Convert the torch.Tensor to a trt.Weights object
482+ trt_weights = to_trt_weights (ctx , torch_value , name , "CONSTANT" , "CONSTANT" )
483+ constant = ctx .net .add_constant (
484+ shape ,
485+ trt_weights ,
486+ )
487+ constant .name = name
490488
491- return constant .get_output (0 )
492- else :
493- raise ValueError (
494- f"Cannot convert tensor '{ name } ' to a TensorRT constant because its value is None."
495- )
489+ return constant .get_output (0 )
496490
497491
498492def get_trt_tensor (
0 commit comments