Skip to content

Commit b2ce1c6

Browse files
committed
Fix test
Signed-off-by: Ryan Nett <[email protected]>
1 parent f0240d0 commit b2ce1c6

File tree

1 file changed

+8
-6
lines changed
  • tensorflow-kotlin-parent/tensorflow-core-kotlin/src/test/kotlin/org/tensorflow

1 file changed

+8
-6
lines changed

tensorflow-kotlin-parent/tensorflow-core-kotlin/src/test/kotlin/org/tensorflow/ExampleTest.kt

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,12 @@ limitations under the License.
1616
*/
1717
package org.tensorflow
1818

19-
import kotlin.test.Test
2019
import org.tensorflow.ndarray.Shape
2120
import org.tensorflow.op.WithOps
2221
import org.tensorflow.op.kotlin.tf
2322
import org.tensorflow.types.TFloat32
24-
import org.tensorflow.types.TInt32
23+
import kotlin.test.Test
24+
import kotlin.test.assertEquals
2525

2626
private fun WithOps.DenseLayer(
2727
name: String,
@@ -30,6 +30,7 @@ private fun WithOps.DenseLayer(
3030
activation: WithOps.(Operand<TFloat32>) -> Operand<TFloat32> = { tf.nn.relu(it) },
3131
): Operand<TFloat32> =
3232
tf.withSubScope(name) {
33+
//TODO should be dynamic
3334
val inputDims = x.shape()[1]
3435
val W = tf.variable(tf.ones<TFloat32>(tf.array(inputDims.toInt(), n)))
3536
val b = tf.variable(tf.ones<TFloat32>(tf.array(n)))
@@ -44,15 +45,16 @@ public class ExampleTest {
4445
tf.placeholderWithDefault(
4546
tf.ones<TFloat32>(tf.array(1, 28, 28, 3)), Shape.of(-1, 28, 28, 3))
4647

47-
var x: Operand<TFloat32> = tf.reshape(input, tf.array(-1))
48-
tf.dtypes.cast<TInt32>(x)
48+
var x: Operand<TFloat32> = tf.reshape(input, tf.array(-1, 28 * 28 * 3))
4949
x = DenseLayer("Layer1", x, 256)
5050
x = DenseLayer("Layer2", x, 64)
51-
val output = DenseLayer("OutputLayer", x, 10) { tf.math.sigmoid(x) }
51+
val output = DenseLayer("OutputLayer", x, 10) { tf.math.sigmoid(it) }
5252

5353
useSession { session ->
54+
session.runInit()
5455
val outputValue = session.runner().fetch(output).run()[0] as TFloat32
55-
println(outputValue.getFloat(0))
56+
assertEquals(Shape.of(1, 10), outputValue.shape())
57+
assertEquals(1.0f, outputValue.getFloat(0, 0))
5658
}
5759
}
5860
}

0 commit comments

Comments
 (0)