diff --git a/tensorflow_quantum/python/util_test.py b/tensorflow_quantum/python/util_test.py index 00715899b..8860ba1f4 100644 --- a/tensorflow_quantum/python/util_test.py +++ b/tensorflow_quantum/python/util_test.py @@ -84,6 +84,61 @@ def test_get_supported_channels(self): len(serializer.SERIALIZER.supported_gate_types()) - len(util.get_supported_gates())) + @parameterized.named_parameters( + ('without_channels', False), + ('with_channels', True), + ) + def test_random_circuit_resolver_batch_shapes_and_types( + self, include_channels): + """Confirm random_circuit_resolver_batch returns the expected types.""" + qubits = cirq.GridQubit.rect(1, 3) + batch_size = 4 + + circuits, resolvers = util.random_circuit_resolver_batch( + qubits, batch_size, n_moments=5, include_channels=include_channels) + + self.assertLen(circuits, batch_size) + self.assertLen(resolvers, batch_size) + for circuit in circuits: + self.assertIsInstance(circuit, cirq.Circuit) + self.assertFalse(cirq.is_parameterized(circuit)) + for resolver in resolvers: + self.assertIsInstance(resolver, cirq.ParamResolver) + self.assertEmpty(resolver.param_dict) + + @parameterized.named_parameters( + ('without_channels', False), + ('with_channels', True), + ) + def test_random_symbol_circuit_resolver_batch_shapes_and_types( + self, include_channels): + """Confirm random_symbol_circuit_resolver_batch returns + the expected types.""" + qubits = cirq.GridQubit.rect(1, 3) + symbols = ['alpha', 'beta', 'gamma'] + batch_size = 4 + + circuits, resolvers = util.random_symbol_circuit_resolver_batch( + qubits, + symbols, + batch_size, + n_moments=5, + include_channels=include_channels) + + self.assertLen(circuits, batch_size) + self.assertLen(resolvers, batch_size) + for circuit in circuits: + self.assertIsInstance(circuit, cirq.Circuit) + self.assertSetEqual(set(util.get_circuit_symbols(circuit)), + set(symbols)) + for resolver in resolvers: + self.assertIsInstance(resolver, cirq.ParamResolver) + self.assertEqual(set(resolver.param_dict.keys()), set(symbols)) + self.assertTrue( + all( + isinstance(value, float) + for value in resolver.param_dict.values())) + @parameterized.parameters(_items_to_tensorize()) def test_convert_to_tensor(self, item): """Test that the convert_to_tensor function works correctly by manually