@@ -16,12 +16,12 @@ limitations under the License.
1616*/
1717package org.tensorflow
1818
19- import kotlin.test.Test
2019import org.tensorflow.ndarray.Shape
2120import org.tensorflow.op.WithOps
2221import org.tensorflow.op.kotlin.tf
2322import org.tensorflow.types.TFloat32
24- import org.tensorflow.types.TInt32
23+ import kotlin.test.Test
24+ import kotlin.test.assertEquals
2525
2626private fun WithOps.DenseLayer (
2727 name : String ,
@@ -30,6 +30,7 @@ private fun WithOps.DenseLayer(
3030 activation : WithOps .(Operand <TFloat32 >) -> Operand <TFloat32 > = { tf.nn.relu(it) },
3131): Operand <TFloat32 > =
3232 tf.withSubScope(name) {
33+ // TODO should be dynamic
3334 val inputDims = x.shape()[1 ]
3435 val W = tf.variable(tf.ones<TFloat32 >(tf.array(inputDims.toInt(), n)))
3536 val b = tf.variable(tf.ones<TFloat32 >(tf.array(n)))
@@ -44,15 +45,16 @@ public class ExampleTest {
4445 tf.placeholderWithDefault(
4546 tf.ones<TFloat32 >(tf.array(1 , 28 , 28 , 3 )), Shape .of(- 1 , 28 , 28 , 3 ))
4647
47- var x: Operand <TFloat32 > = tf.reshape(input, tf.array(- 1 ))
48- tf.dtypes.cast<TInt32 >(x)
48+ var x: Operand <TFloat32 > = tf.reshape(input, tf.array(- 1 , 28 * 28 * 3 ))
4949 x = DenseLayer (" Layer1" , x, 256 )
5050 x = DenseLayer (" Layer2" , x, 64 )
51- val output = DenseLayer (" OutputLayer" , x, 10 ) { tf.math.sigmoid(x ) }
51+ val output = DenseLayer (" OutputLayer" , x, 10 ) { tf.math.sigmoid(it ) }
5252
5353 useSession { session ->
54+ session.runInit()
5455 val outputValue = session.runner().fetch(output).run ()[0 ] as TFloat32
55- println (outputValue.getFloat(0 ))
56+ assertEquals(Shape .of(1 , 10 ), outputValue.shape())
57+ assertEquals(1.0f , outputValue.getFloat(0 , 0 ))
5658 }
5759 }
5860 }
0 commit comments