diff --git a/ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting.py b/ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting.py
index 9ae7f50..f7b4107 100644
--- a/ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting.py
+++ b/ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting.py
@@ -141,23 +141,30 @@ def materialize_fc_conv(
       weight_tensor,
       graph_info.buffers,
   )
-  assert weight_content is not None
-  if not (
-      quant_params := tensor_quant_params_cache.lookup(
-          weight_tensor.buffer, _FP16_QUANT_CONFIG
-      )
-  ):
-    quant_params = qtyping.NonLinearQuantParams(
-        num_bits=16, quantized_data=weight_content.astype(np.float16)
-    )
-    tensor_quant_params_cache.insert(
-        weight_tensor.buffer, _FP16_QUANT_CONFIG, quant_params
-    )
-  op2weight_params = qtyping.OpToTensorParams(
-      subgraph_op_id=op_info.subgraph_op_index,
-      parameters=quant_params,
-      transformations=[_QuantTransformation.ADD_DEQUANTIZE],
-  )
+  if weight_content is None:
+    # If the weight data is not stored in the flatbuffer (e.g. dynamic
+    # weights), skip quantization for this tensor.
+    op2weight_params = qtyping.OpToTensorParams(
+        subgraph_op_id=op_info.subgraph_op_index,
+        transformations=[_QuantTransformation.NO_QUANTIZE],
+    )
+  else:
+    if not (
+        quant_params := tensor_quant_params_cache.lookup(
+            weight_tensor.buffer, _FP16_QUANT_CONFIG
+        )
+    ):
+      quant_params = qtyping.NonLinearQuantParams(
+          num_bits=16, quantized_data=weight_content.astype(np.float16)
+      )
+      tensor_quant_params_cache.insert(
+          weight_tensor.buffer, _FP16_QUANT_CONFIG, quant_params
+      )
+    op2weight_params = qtyping.OpToTensorParams(
+        subgraph_op_id=op_info.subgraph_op_index,
+        parameters=quant_params,
+        transformations=[_QuantTransformation.ADD_DEQUANTIZE],
+    )
   op_tensor_params.append(
       qtyping.TensorTransformationParams(
           tensor_name=tfl_flatbuffer_utils.get_tensor_name(weight_tensor),
@@ -241,23 +241,30 @@ def materialize_conv2d_transpose(
       weight_tensor,
       graph_info.buffers,
   )
-  assert weight_content is not None
-  if not (
-      quant_params := tensor_quant_params_cache.lookup(
-          weight_tensor.buffer, _FP16_QUANT_CONFIG
-      )
-  ):
-    quant_params = qtyping.NonLinearQuantParams(
-        num_bits=16, quantized_data=weight_content.astype(np.float16)
-    )
-    tensor_quant_params_cache.insert(
-        weight_tensor.buffer, _FP16_QUANT_CONFIG, quant_params
-    )
-  op2weight_params = qtyping.OpToTensorParams(
-      subgraph_op_id=op_info.subgraph_op_index,
-      parameters=quant_params,
-      transformations=[_QuantTransformation.ADD_DEQUANTIZE],
-  )
+  if weight_content is None:
+    # If the weight data is not stored in the flatbuffer (e.g. dynamic
+    # weights), skip quantization for this tensor.
+    op2weight_params = qtyping.OpToTensorParams(
+        subgraph_op_id=op_info.subgraph_op_index,
+        transformations=[_QuantTransformation.NO_QUANTIZE],
+    )
+  else:
+    if not (
+        quant_params := tensor_quant_params_cache.lookup(
+            weight_tensor.buffer, _FP16_QUANT_CONFIG
+        )
+    ):
+      quant_params = qtyping.NonLinearQuantParams(
+          num_bits=16, quantized_data=weight_content.astype(np.float16)
+      )
+      tensor_quant_params_cache.insert(
+          weight_tensor.buffer, _FP16_QUANT_CONFIG, quant_params
+      )
+    op2weight_params = qtyping.OpToTensorParams(
+        subgraph_op_id=op_info.subgraph_op_index,
+        parameters=quant_params,
+        transformations=[_QuantTransformation.ADD_DEQUANTIZE],
+    )
   op_tensor_params.append(
       qtyping.TensorTransformationParams(
           tensor_name=tfl_flatbuffer_utils.get_tensor_name(weight_tensor),
diff --git a/ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting_test.py b/ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting_test.py
index 39deb0a..276c23d 100644
--- a/ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting_test.py
+++ b/ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting_test.py
@@ -14,6 +14,7 @@
 # ==============================================================================
 
 import pathlib
+from unittest import mock
 
 from absl.testing import absltest
 from absl.testing import parameterized
@@ -335,6 +336,50 @@ def test_conv2d_weight_only_succeeds(self):
         float_casting.materialize_fc_conv,
     )
 
+  def test_conv2d_dynamic_weight_succeeds(self):
+    # Read from Model Explorer.
+    subgraph0 = self._test_model.subgraphs[0]
+    subgraph_op_id = 0
+    op = subgraph0.operators[subgraph_op_id]
+
+    op_info = qtyping.OpInfo(
+        op=op,
+        op_name=_TFLOpName.CONV_2D,
+        subgraph_op_index=subgraph_op_id,
+        op_quant_config=qtyping.OpQuantizationConfig(
+            weight_tensor_config=_TensorQuantConfig(
+                num_bits=16, dtype=qtyping.TensorDataType.FLOAT
+            ),
+            compute_precision=_ComputePrecision.FLOAT,  # WEIGHT_ONLY.
+            explicit_dequantize=True,
+        ),
+    )
+
+    with mock.patch.object(
+        tfl_flatbuffer_utils, "get_tensor_data", return_value=None
+    ):
+      tensor_quant_params = float_casting.materialize_fc_conv(
+          op_info,
+          self._graph_info,
+          self._tensor_name_to_qsv,
+          tensor_quant_params_cache=common_utils.TensorQuantParamsCache(),
+      )
+
+    # 4 tensors (input, weight, bias, output).
+    self.assertLen(tensor_quant_params, 4)
+    # The weight tensor is at index 1.
+    weight_params = tensor_quant_params[1]
+    self.assertEqual(weight_params.tensor_name, "sequential/conv2d/Conv2D")
+    self.assertLen(weight_params.consumers, 1)
+
+    # The transformation should be NO_QUANTIZE since get_tensor_data is
+    # mocked to return None.
+    op_params = weight_params.consumers[0]
+    self.assertSequenceEqual(
+        op_params.transformations, [_QuantTransformation.NO_QUANTIZE]
+    )
+    self.assertIsNone(op_params.parameters)
+
   @parameterized.named_parameters(
       dict(
           testcase_name="invalid_fc",
@@ -468,6 +513,70 @@ def test_conv2d_transpose_weight_only_succeeds(self):
           is_inbounding_tensor=True,
       )
 
+  def test_conv2d_transpose_dynamic_weight_succeeds(self):
+    # Read from Model Explorer.
+    test_model_path = str(
+        pathlib.Path(_TEST_DATA_PREFIX_PATH)
+        / "single_conv2d_transpose_bias.tflite"
+    )
+
+    test_model = tfl_flatbuffer_utils.read_model(test_model_path)
+    # The test model has one subgraph for now.
+    graph_info = qtyping.GraphInfo(
+        subgraph_tensors=test_model.subgraphs[0].tensors,
+        buffers=test_model.buffers,
+    )
+
+    subgraph0 = test_model.subgraphs[0]
+    subgraph_op_id = 0
+    op = subgraph0.operators[subgraph_op_id]
+
+    op_info = qtyping.OpInfo(
+        op=op,
+        op_name=_TFLOpName.CONV_2D_TRANSPOSE,
+        subgraph_op_index=subgraph_op_id,
+        op_quant_config=qtyping.OpQuantizationConfig(
+            weight_tensor_config=_TensorQuantConfig(
+                num_bits=16, dtype=qtyping.TensorDataType.FLOAT
+            ),
+            compute_precision=_ComputePrecision.FLOAT,  # WEIGHT_ONLY.
+            explicit_dequantize=True,
+        ),
+    )
+
+    with mock.patch.object(
+        tfl_flatbuffer_utils, "get_tensor_data", return_value=None
+    ):
+      tensor_quant_params = float_casting.materialize_conv2d_transpose(
+          op_info,
+          graph_info,
+          self._tensor_name_to_qsv,
+          tensor_quant_params_cache=common_utils.TensorQuantParamsCache(),
+      )
+
+    _, _, bias_tensor, _ = tfl_flatbuffer_utils.parse_fc_bmm_conv_tensors(
+        op_info.op, graph_info.subgraph_tensors
+    )
+
+    num_configs = 4 if bias_tensor is not None else 3
+    self.assertLen(tensor_quant_params, num_configs)
+
+    # The weight tensor is at index 1.
+    weight_params = tensor_quant_params[1]
+    self.assertEqual(
+        weight_params.tensor_name,
+        "sequential_5/conv2d_transpose_3/conv2d_transpose",
+    )
+    self.assertLen(weight_params.consumers, 1)
+
+    # The transformation should be NO_QUANTIZE since get_tensor_data is
+    # mocked to return None.
+    op_params = weight_params.consumers[0]
+    self.assertSequenceEqual(
+        op_params.transformations, [_QuantTransformation.NO_QUANTIZE]
+    )
+    self.assertIsNone(op_params.parameters)
+
   def test_depthwise_conv2d_weight_only_succeeds(self):
     # Read from Model Explorer.
     test_model_path = str(