From c20b71023b74b98aca1c8fb9086f4d5b4daed0d9 Mon Sep 17 00:00:00 2001 From: moradology Date: Thu, 6 Feb 2020 16:27:05 -0500 Subject: [PATCH] Checkpoint commit; not yet compiling --- .../scala/geotrellis/raster/ArrowTensor.scala | 12 +-- .../geotrellis/raster/BufferedTensor.scala | 2 +- .../org/apache/spark/sql/rf/TensorUDT.scala | 50 ++++++---- .../expressions/BinaryLocalRasterOp.scala | 5 +- .../expressions/DynamicExtractors.scala | 3 +- .../RasterSourcesToTensorRefs.scala | 11 +-- .../transformers/TensorRefToTensor.scala | 4 +- .../model/TensorDataContext.scala | 65 +++++++++++++ .../rasterframes/model/Voxels.scala | 81 ++++++++++++++++ .../rasterframes/ref/DeferredTensorRef.scala | 42 +++++++++ .../rasterframes/ref/TensorRef.scala | 28 +++--- .../tensors/InternalRowTensor.scala | 92 +++++++++++++++++++ .../rasterframes/tensors/RFTensor.scala | 45 +++++++++ 13 files changed, 386 insertions(+), 54 deletions(-) create mode 100644 core/src/main/scala/org/locationtech/rasterframes/model/TensorDataContext.scala create mode 100644 core/src/main/scala/org/locationtech/rasterframes/model/Voxels.scala create mode 100644 core/src/main/scala/org/locationtech/rasterframes/ref/DeferredTensorRef.scala create mode 100644 core/src/main/scala/org/locationtech/rasterframes/tensors/InternalRowTensor.scala create mode 100644 core/src/main/scala/org/locationtech/rasterframes/tensors/RFTensor.scala diff --git a/core/src/main/scala/geotrellis/raster/ArrowTensor.scala b/core/src/main/scala/geotrellis/raster/ArrowTensor.scala index 8a731e12e..b4b87c8de 100644 --- a/core/src/main/scala/geotrellis/raster/ArrowTensor.scala +++ b/core/src/main/scala/geotrellis/raster/ArrowTensor.scala @@ -15,16 +15,16 @@ import org.apache.arrow.vector.{Float8Vector, VectorSchemaRoot} import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.locationtech.rasterframes.encoders.CatalystSerializerEncoder +import org.locationtech.rasterframes.tensors.RFTensor import spire.syntax.cfor._ import 
scala.collection.JavaConverters._ -case class ArrowTensor(val vector: Float8Vector, val shape: Seq[Int]) extends CellGrid { + + +case class ArrowTensor(val vector: Float8Vector, val shape: Seq[Int]) extends RFTensor { // TODO: Should we be using ArrowBuf here directly, since Arrow Tensor can not have pages? - def rows = shape(1) - def cols = shape(2) - val cellType = DoubleCellType // TODO: Figure out how to work this crazy thing // def copy(implicit alloc: BufferAllocator) = { @@ -189,8 +189,8 @@ case class ArrowTensor(val vector: Float8Vector, val shape: Seq[Int]) extends Ce object ArrowTensor { import org.apache.spark.sql.rf.TensorUDT._ - implicit val arrowTensorEncoder: ExpressionEncoder[ArrowTensor] = - CatalystSerializerEncoder[ArrowTensor](true) + implicit val arrowTensorEncoder: ExpressionEncoder[RFTensor] = + CatalystSerializerEncoder[RFTensor](true) val allocator = new RootAllocator(Long.MaxValue) diff --git a/core/src/main/scala/geotrellis/raster/BufferedTensor.scala b/core/src/main/scala/geotrellis/raster/BufferedTensor.scala index 4a4f51607..39b7c6721 100644 --- a/core/src/main/scala/geotrellis/raster/BufferedTensor.scala +++ b/core/src/main/scala/geotrellis/raster/BufferedTensor.scala @@ -26,7 +26,7 @@ case class BufferedTensor( val bufferRows: Int, val bufferCols: Int, val extent: Option[Extent] -) extends CellGrid { +) extends RFTensor { val cellType = DoubleCellType diff --git a/core/src/main/scala/org/apache/spark/sql/rf/TensorUDT.scala b/core/src/main/scala/org/apache/spark/sql/rf/TensorUDT.scala index adf7512c3..01ffbc86c 100644 --- a/core/src/main/scala/org/apache/spark/sql/rf/TensorUDT.scala +++ b/core/src/main/scala/org/apache/spark/sql/rf/TensorUDT.scala @@ -26,31 +26,32 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.types.{DataType, _} import org.locationtech.rasterframes.encoders.CatalystSerializer import org.locationtech.rasterframes.encoders.CatalystSerializer._ -import 
org.locationtech.rasterframes.model.{Cells, TileDataContext} -import org.locationtech.rasterframes.ref.RasterRef.RasterRefTile -import org.locationtech.rasterframes.tiles.InternalRowTile +import org.locationtech.rasterframes.model.{Voxels, TensorDataContext} +import org.locationtech.rasterframes.ref.DeferredTensorRef +//import org.locationtech.rasterframes.ref.RasterRef.RasterRefTile +import org.locationtech.rasterframes.tensors.{RFTensor, InternalRowTensor} @SQLUserDefinedType(udt = classOf[TensorUDT]) -class TensorUDT extends UserDefinedType[ArrowTensor] { +class TensorUDT extends UserDefinedType[RFTensor] { import TensorUDT._ override def typeName = TensorUDT.typeName override def pyUDT: String = "pyrasterframes.rf_types.TensorUDT" - def userClass: Class[ArrowTensor] = classOf[ArrowTensor] + def userClass: Class[RFTensor] = classOf[RFTensor] - def sqlType: StructType = schemaOf[ArrowTensor] + def sqlType: StructType = schemaOf[RFTensor] - override def serialize(obj: ArrowTensor): InternalRow = + override def serialize(obj: RFTensor): InternalRow = Option(obj) .map(_.toInternalRow) .orNull - override def deserialize(datum: Any): ArrowTensor = + override def deserialize(datum: Any): RFTensor = Option(datum) .collect { - case ir: InternalRow ⇒ ir.to[ArrowTensor] + case ir: InternalRow ⇒ ir.to[RFTensor] } .orNull @@ -61,23 +62,34 @@ class TensorUDT extends UserDefinedType[ArrowTensor] { } case object TensorUDT { - UDTRegistration.register(classOf[ArrowTensor].getName, classOf[TensorUDT].getName) + UDTRegistration.register(classOf[RFTensor].getName, classOf[TensorUDT].getName) final val typeName: String = "tensor" - implicit def tensorSerializer: CatalystSerializer[ArrowTensor] = new CatalystSerializer[ArrowTensor] { + implicit def tensorSerializer: CatalystSerializer[RFTensor] = new CatalystSerializer[RFTensor] { override val schema: StructType = StructType(Seq( - StructField("arrow_tensor", BinaryType, true) + StructField("tensor_context", 
schemaOf[TensorDataContext], true), + StructField("tensor_data", schemaOf[Voxels], false) )) - override def to[R](t: ArrowTensor, io: CatalystIO[R]): R = io.create { - t.toArrowBytes() - } - - override def from[R](row: R, io: CatalystIO[R]): ArrowTensor = { - val bytes = io.getByteArray(row, 0) - ArrowTensor.fromArrowMessage(bytes) + override def to[R](t: RFTensor, io: CatalystIO[R]): R = io.create( + t match { + case _: DeferredTensorRef => null + case o => io.to(TensorDataContext(o)) + }, + io.to(Voxels(t)) + ) + + override def from[R](row: R, io: CatalystIO[R]): RFTensor = { + val voxels = io.get[Voxels](row, 1) + + row match { + case ir: InternalRow if !voxels.isRef ⇒ new InternalRowTensor(ir) + case _ ⇒ + /* `to` stores a null context for refs, so the context is decoded only for inline voxel bytes. */ + voxels.data.fold(_ ⇒ voxels.toTensor(io.get[TensorDataContext](row, 0)), ref ⇒ DeferredTensorRef(ref)) + } } } } diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/BinaryLocalRasterOp.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/BinaryLocalRasterOp.scala index 3da1e9262..1f5f6d55e 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/BinaryLocalRasterOp.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/BinaryLocalRasterOp.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, import org.apache.spark.sql.catalyst.expressions.BinaryExpression import org.apache.spark.sql.rf.{TileUDT, TensorUDT} import org.apache.spark.sql.types.DataType +import org.locationtech.rasterframes.tensors.RFTensor import org.locationtech.rasterframes.encoders.CatalystSerializer._ import org.locationtech.rasterframes.expressions.DynamicExtractors._ import org.slf4j.LoggerFactory @@ -92,9 +93,9 @@ trait BinaryLocalRasterOp extends BinaryExpression { } (context, isTensor) match { - case (Some(ctx), true) ⇒ result.asInstanceOf[ArrowTensor].toInternalRow + case (Some(ctx), true) ⇒ result.asInstanceOf[RFTensor].toInternalRow case (Some(ctx), false) ⇒ 
ctx.toProjectRasterTile(result.asInstanceOf[Tile]).toInternalRow - case (None, true) ⇒ result.asInstanceOf[ArrowTensor].toInternalRow + case (None, true) ⇒ result.asInstanceOf[RFTensor].toInternalRow case (None, false) ⇒ result.asInstanceOf[Tile].toInternalRow } } diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/DynamicExtractors.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/DynamicExtractors.scala index ce6c2889e..883bcc829 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/DynamicExtractors.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/DynamicExtractors.scala @@ -34,6 +34,7 @@ import org.locationtech.jts.geom.{Envelope, Point} import org.locationtech.rasterframes.encoders.CatalystSerializer._ import org.locationtech.rasterframes.model.{LazyCRS, TileContext} import org.locationtech.rasterframes.ref.{ProjectedRasterLike, RasterRef, RasterSource} +import org.locationtech.rasterframes.tensors.RFTensor import org.locationtech.rasterframes.tiles.ProjectedRasterTile private[rasterframes] @@ -95,7 +96,7 @@ object DynamicExtractors { case _: TileUDT => (row: InternalRow) => row.to[Tile](TileUDT.tileSerializer) case _: TensorUDT => - (row: InternalRow) => row.to[ArrowTensor](TensorUDT.tensorSerializer) + (row: InternalRow) => row.to[RFTensor](TensorUDT.tensorSerializer) case _: BufferedTensorUDT => (row: InternalRow) => row.to[BufferedTensor](BufferedTensorUDT.bufferedTensorSerializer) case _: RasterSourceUDT => diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/generators/RasterSourcesToTensorRefs.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/generators/RasterSourcesToTensorRefs.scala index 30d4669ae..42ed65c8d 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/generators/RasterSourcesToTensorRefs.scala +++ 
b/core/src/main/scala/org/locationtech/rasterframes/expressions/generators/RasterSourcesToTensorRefs.scala @@ -47,7 +47,7 @@ import scala.util.control.NonFatal * @since 9/6/18 */ // BUFFER HERE -case class RasterSourcesToTensorRefs(child: Expression, subtileDims: Option[TileDimensions] = None) extends UnaryExpression +case class RasterSourcesToTensorRefs(child: Expression, subtileDims: Option[TileDimensions] = None, bufferPixels: Int = 0) extends UnaryExpression with Generator with CodegenFallback with ExpectsInputTypes { import TensorRef._ import org.locationtech.rasterframes.expressions.transformers.PatternToRasterSources._ @@ -78,14 +78,11 @@ case class RasterSourcesToTensorRefs(child: Expression, subtileDims: Option[Tile val sampleRS = rss.head.source - val maybeSubs = subtileDims.map { dims => - val subGB = sampleRS.layoutBounds(dims) - subGB.map(gb => (gb, sampleRS.rasterExtent.extentFor(gb, clamp = true))) - } + val maybeSubs = subtileDims.map { dims => sampleRS.layoutBounds(dims) } val trefs = maybeSubs.map { subs => - subs.map { case (gb, extent) => TensorRef(rss, Some(extent), Some(gb)) } - }.getOrElse(Seq(TensorRef(rss, None, None))) + subs.map { gb => TensorRef(rss, Some(gb), bufferPixels) } + }.getOrElse(Seq(TensorRef(rss, None, bufferPixels))) trefs.map{ tref => InternalRow(tref.toInternalRow) } } diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/TensorRefToTensor.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/TensorRefToTensor.scala index c23edc080..90fa9b7e1 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/TensorRefToTensor.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/TensorRefToTensor.scala @@ -55,9 +55,7 @@ case class TensorRefToTensor(child: Expression, bufferPixels: Int) extends Unary override protected def nullSafeEval(input: Any): Any = { implicit val ser = ProjectedBufferedTensor.serializer 
val ref = row(input).to[TensorRef] - val realized = ref.realizedTensor(bufferPixels) - - realized.toInternalRow + ref.tensor.toInternalRow } } diff --git a/core/src/main/scala/org/locationtech/rasterframes/model/TensorDataContext.scala b/core/src/main/scala/org/locationtech/rasterframes/model/TensorDataContext.scala new file mode 100644 index 000000000..a0b9ca3ee --- /dev/null +++ b/core/src/main/scala/org/locationtech/rasterframes/model/TensorDataContext.scala @@ -0,0 +1,65 @@ +/* + * This software is licensed under the Apache 2 license, quoted below. + * + * Copyright 2019 Astraea, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * [http://www.apache.org/licenses/LICENSE-2.0] + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + * + * SPDX-License-Identifier: Apache-2.0 + * + */ + +package org.locationtech.rasterframes.model + +import org.locationtech.rasterframes.encoders.CatalystSerializer._ +import geotrellis.raster.{CellType, Tile, ArrowTensor} +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.types.{StructField, StructType, IntegerType} +import org.locationtech.rasterframes.encoders.{CatalystSerializer, CatalystSerializerEncoder} + +/** Encapsulates all information about a tile aside from actual cell values. */ +case class TensorDataContext(depth: Int, rows: Int, cols: Int, bufferPixels: Int) +object TensorDataContext { + + /** Extracts the TensorDataContext from a Tile. 
*/ + def apply(t: RFTensor): TensorDataContext = { + require(t.depth <= Short.MaxValue, s"RasterFrames doesn't support tiles of size ${t.depth}") + require(t.cols <= Short.MaxValue, s"RasterFrames doesn't support tiles of size ${t.cols}") + require(t.rows <= Short.MaxValue, s"RasterFrames doesn't support tiles of size ${t.rows}") + TensorDataContext(t.shape(0), t.shape(1), t.shape(2), bufferPixels) /* FIXME(review): `bufferPixels` is not in scope here; it has to be derived from `t` or passed in. */ + } + + implicit val serializer: CatalystSerializer[TensorDataContext] = new CatalystSerializer[TensorDataContext] { + override val schema: StructType = StructType(Seq( + StructField("depth", IntegerType, false), + StructField("rows", IntegerType, false), /* field order must match `to`/`from` below: depth, rows, cols, bufferPixels */ + StructField("cols", IntegerType, false), + StructField("bufferPixels", IntegerType, false) + )) + + override protected def to[R](t: TensorDataContext, io: CatalystIO[R]): R = io.create( + t.depth, + t.rows, + t.cols, + t.bufferPixels + ) + override protected def from[R](t: R, io: CatalystIO[R]): TensorDataContext = TensorDataContext( + io.getInt(t, 0), + io.getInt(t, 1), + io.getInt(t, 2), + io.getInt(t, 3) + ) + } + + implicit def encoder: ExpressionEncoder[TensorDataContext] = CatalystSerializerEncoder[TensorDataContext]() +} diff --git a/core/src/main/scala/org/locationtech/rasterframes/model/Voxels.scala b/core/src/main/scala/org/locationtech/rasterframes/model/Voxels.scala new file mode 100644 index 000000000..f6609bec0 --- /dev/null +++ b/core/src/main/scala/org/locationtech/rasterframes/model/Voxels.scala @@ -0,0 +1,81 @@ +/* + * This software is licensed under the Apache 2 license, quoted below. + * + * Copyright 2019 Astraea, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. 
You may obtain a copy of + * the License at + * + * [http://www.apache.org/licenses/LICENSE-2.0] + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + * + * SPDX-License-Identifier: Apache-2.0 + * + */ + +package org.locationtech.rasterframes.model + +import geotrellis.raster._ +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.types.{BinaryType, StructField, StructType} +import org.locationtech.rasterframes +import org.locationtech.rasterframes.encoders.CatalystSerializer._ +import org.locationtech.rasterframes.encoders.{CatalystSerializer, CatalystSerializerEncoder} +import org.locationtech.rasterframes.ref.{TensorRef, DeferredTensorRef} +import org.locationtech.rasterframes.tensors.RFTensor + +/** Represents either serialized Arrow tensor bytes (Left) or a reference to the data (Right). */ +case class Voxels(data: Either[Array[Byte], TensorRef]) { + def isRef: Boolean = data.isRight + + /** Convert voxels into a DeferredTensorRef (for a ref) or a BufferedTensor wrapping the decoded ArrowTensor (for bytes); `ctx` is only read on the bytes path. */ + def toTensor(ctx: TensorDataContext): RFTensor = data.fold( + bytes => { + val nakedTensor = ArrowTensor.fromArrowMessage(bytes) + BufferedTensor(nakedTensor, ctx.bufferPixels, ctx.bufferPixels, None) /* NOTE(review): assumes BufferedTensor(tensor, bufferRows, bufferCols, extent) - confirm against BufferedTensor.scala */ + }, + ref => DeferredTensorRef(ref) + ) +} + +object Voxels { + /** Extracts the Voxels from a Tensor. 
*/ + def apply(t: RFTensor): Voxels = { + t match { + case arrowTensor: ArrowTensor => + Voxels(Left(arrowTensor.toArrowBytes())) + case ref: DeferredTensorRef => + Voxels(Right(ref.deferred)) + case other => /* any other RFTensor has no voxel encoding here yet; fail with a message instead of a MatchError */ + throw new IllegalArgumentException(s"Cannot extract voxels from ${other.getClass.getName}") + } + } + + implicit def voxelsSerializer: CatalystSerializer[Voxels] = new CatalystSerializer[Voxels] { + override val schema: StructType = + StructType( + Seq( + StructField("voxels", BinaryType, true), + StructField("ref", schemaOf[TensorRef], true) + )) + override protected def to[R](t: Voxels, io: CatalystSerializer.CatalystIO[R]): R = io.create( + t.data.left.getOrElse(null), + t.data.right.map(tr => io.to(tr)).right.getOrElse(null) + ) + override protected def from[R](t: R, io: CatalystSerializer.CatalystIO[R]): Voxels = { + if (!io.isNullAt(t, 0)) + Voxels(Left(io.getByteArray(t, 0))) + else if (!io.isNullAt(t, 1)) + Voxels(Right(io.get[TensorRef](t, 1))) + else throw new IllegalArgumentException("must be either arrow tensor data or a ref, but not null") + } + } + + implicit def encoder: ExpressionEncoder[Voxels] = CatalystSerializerEncoder[Voxels]() +} diff --git a/core/src/main/scala/org/locationtech/rasterframes/ref/DeferredTensorRef.scala b/core/src/main/scala/org/locationtech/rasterframes/ref/DeferredTensorRef.scala new file mode 100644 index 000000000..bb33acf3a --- /dev/null +++ b/core/src/main/scala/org/locationtech/rasterframes/ref/DeferredTensorRef.scala @@ -0,0 +1,42 @@ +/* + * This software is licensed under the Apache 2 license, quoted below. + * + * Copyright 2018 Astraea, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. 
You may obtain a copy of + * the License at + * + * [http://www.apache.org/licenses/LICENSE-2.0] + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + * + * SPDX-License-Identifier: Apache-2.0 + * + */ + +package org.locationtech.rasterframes.ref + +import com.typesafe.scalalogging.LazyLogging +import geotrellis.proj4.CRS +import geotrellis.raster._ +import geotrellis.vector.{Extent, ProjectedExtent} +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.rf.RasterSourceUDT +import org.apache.spark.sql.types.{IntegerType, StructField, StructType, ArrayType} +import org.apache.spark.sql.Encoder +import org.locationtech.rasterframes._ +import org.locationtech.rasterframes.encoders.CatalystSerializer.{CatalystIO, _} +import org.locationtech.rasterframes.encoders.{CatalystSerializer, CatalystSerializerEncoder} +import org.locationtech.rasterframes.ref.RasterSource._ +import org.locationtech.rasterframes.tensors.{ProjectedBufferedTensor, RFTensor} +import org.locationtech.rasterframes.expressions.transformers.PatternToRasterSources._ + + +case class DeferredTensorRef(deferred: TensorRef) extends RFTensor { + def shape = Seq(deferred.sources.length, deferred.rows, deferred.cols) +} diff --git a/core/src/main/scala/org/locationtech/rasterframes/ref/TensorRef.scala b/core/src/main/scala/org/locationtech/rasterframes/ref/TensorRef.scala index ed05a1d28..738fd3f82 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/ref/TensorRef.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/ref/TensorRef.scala @@ -42,15 +42,13 @@ import org.locationtech.rasterframes.expressions.transformers.PatternToRasterSou * * @since 8/21/18 */ -case class 
TensorRef(sources: Seq[RasterSourceWithBand], subextent: Option[Extent], subgrid: Option[GridBounds]) +case class TensorRef(sources: Seq[RasterSourceWithBand], subgrid: Option[GridBounds], bufferPixels: Int = 0) extends ProjectedRasterLike { def sample = sources.head.source def crs: CRS = sample.crs def cols: Int = grid.width def rows: Int = grid.height def cellType: CellType = sample.cellType - //def tile: ProjectedRasterTile = RasterRefTile(this) - protected lazy val grid: GridBounds = subgrid.getOrElse(sample.rasterExtent.gridBoundsFor(extent, true)) @@ -58,20 +56,23 @@ case class TensorRef(sources: Seq[RasterSourceWithBand], subextent: Option[Exten // This should correspond to the gridded region to which this tensor reference refers lazy val extent: Extent = RasterExtent(sample.extent, sample.cellSize).extentFor(grid) - lazy val realizedTensor: ArrowTensor = { - //RasterRef.log.trace(s"Fetching $extent ($grid) from band $bandIndex of $sample") - val tiles = sources.map({ case RasterSourceWithBand(rs, band) => - rs.read(grid, Seq(band)).tile.band(0) - }) - ArrowTensor.stackTiles(tiles) - } + // lazy val realizedTensor: ArrowTensor = { + // //RasterRef.log.trace(s"Fetching $extent ($grid) from band $bandIndex of $sample") + // val tiles = sources.map({ case RasterSourceWithBand(rs, band) => + // rs.read(grid, Seq(band)).tile.band(0) + // }) + // ArrowTensor.stackTiles(tiles) + // } + + def tensor: DeferredTensorRef = DeferredTensorRef(this) - def realizedTensor(bufferPixels: Int): ProjectedBufferedTensor = { + lazy val realizedTensor: ProjectedBufferedTensor = { //RasterRef.log.trace(s"Fetching $extent ($grid) from band $bandIndex of $sample") val bufferedGrid = grid.buffer(bufferPixels) val tiles = sources.map({ case RasterSourceWithBand(rs, band) => val tile = rs.read(bufferedGrid, Seq(band)).tile.band(0) + //val tile = RasterRef(rs, band, bufferedGrid, bufferPixels)// .read(bufferedGrid, Seq(band)).tile.band(band) val rsBounds = GridBounds(0, 0, rs.cols - 1, 
rs.rows - 1) @@ -144,16 +145,13 @@ object TensorRef extends LazyLogging { override def to[R](t: TensorRef, io: CatalystIO[R]): R = io.create( io.toSeq(t.sources), - t.subextent.map(io.to[Extent]).orNull, t.subgrid.map(io.to[GridBounds]).orNull ) override def from[R](row: R, io: CatalystIO[R]): TensorRef = TensorRef( io.getSeq[RasterSourceWithBand](row, 0), if (io.isNullAt(row, 1)) None - else Option(io.get[Extent](row, 1)), - if (io.isNullAt(row, 2)) None - else Option(io.get[GridBounds](row, 2)) + else Option(io.get[GridBounds](row, 1)) /* subgrid is field 1 now that subextent is no longer written by `to` */ ) } diff --git a/core/src/main/scala/org/locationtech/rasterframes/tensors/InternalRowTensor.scala b/core/src/main/scala/org/locationtech/rasterframes/tensors/InternalRowTensor.scala new file mode 100644 index 000000000..608ea4fe9 --- /dev/null +++ b/core/src/main/scala/org/locationtech/rasterframes/tensors/InternalRowTensor.scala @@ -0,0 +1,92 @@ +/* + * This software is licensed under the Apache 2 license, quoted below. + * + * Copyright 2018 Astraea, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * [http://www.apache.org/licenses/LICENSE-2.0] + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + * + */ + +package org.locationtech.rasterframes.tensors + +import java.nio.ByteBuffer + +import org.locationtech.rasterframes.encoders.CatalystSerializer.CatalystIO +import geotrellis.raster._ +import org.apache.spark.sql.catalyst.InternalRow +import org.locationtech.rasterframes.model.{Voxels, TensorDataContext} + +/** + * Wrapper around a tensor encoded in a Catalyst `InternalRow` (field 0: TensorDataContext, + * field 1: Voxels), for the purpose of providing compatible semantics over common operations. + * + * @since 11/29/17 + */ +class InternalRowTensor(val mem: InternalRow) extends RFTensor { + import InternalRowTensor._ + + //override def toArrayTile(): ArrayTile = realizedTile.toArrayTile() + + // TODO: We want to reimplement relevant delegated methods so that they read directly from tungsten storage + lazy val realizedTensor: RFTensor = voxels.toTensor(cellContext) + + protected override def delegate: RFTensor = realizedTensor + + private def cellContext: TensorDataContext = + CatalystIO[InternalRow].get[TensorDataContext](mem, 0) + + private def voxels: Voxels = CatalystIO[InternalRow].get[Voxels](mem, 1) + + override def depth: Int = cellContext.depth + + /** Retrieve the number of columns from the internal encoding. */ + override def cols: Int = cellContext.cols + + /** Retrieve the number of rows from the internal encoding. */ + override def rows: Int = cellContext.rows + + /** Get the internally encoded tile data cells. */ + override lazy val toBytes: Array[Byte] = { + voxels.data.left + .getOrElse(throw new IllegalStateException( + "Expected tile cell bytes, but received RasterRef instead: " + voxels.data.right.get) + ) + } + + // private lazy val toByteBuffer: ByteBuffer = { + // val data = toBytes + // if(data.length < cols * rows && cellType.name != "bool") { + // // Handling constant tiles like this is inefficient and ugly. 
All the edge + // // cases associated with them create too much undue complexity for + // // something that's unlikely to be + // // used much in production to warrant handling them specially. + // // If a more efficient handling is necessary, consider a flag in + // // the UDT struct. + // ByteBuffer.wrap(toArrayTile().toBytes()) + // } else ByteBuffer.wrap(data) + // } + + // /** Reads the cell value at the given index as an Int. */ + // def apply(i: Int): Int = cellReader.apply(i) + + // /** Reads the cell value at the given index as a Double. */ + // def applyDouble(i: Int): Double = cellReader.applyDouble(i) + + def copy = new InternalRowTensor(mem.copy) + + //override def toString: String = ShowableTile.show(this) +} + +object InternalRowTensor {} diff --git a/core/src/main/scala/org/locationtech/rasterframes/tensors/RFTensor.scala b/core/src/main/scala/org/locationtech/rasterframes/tensors/RFTensor.scala new file mode 100644 index 000000000..d6d18bf93 --- /dev/null +++ b/core/src/main/scala/org/locationtech/rasterframes/tensors/RFTensor.scala @@ -0,0 +1,45 @@ +/* + * This software is licensed under the Apache 2 license, quoted below. + * + * Copyright 2018 Astraea, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * [http://www.apache.org/licenses/LICENSE-2.0] + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + * + */ + +package org.locationtech.rasterframes.tensors + +import com.typesafe.scalalogging.LazyLogging +import geotrellis.proj4.CRS +import geotrellis.raster._ +import geotrellis.vector.{Extent, ProjectedExtent} +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.rf.RasterSourceUDT +import org.apache.spark.sql.types.{IntegerType, StructField, StructType, ArrayType} +import org.apache.spark.sql.Encoder +import org.locationtech.rasterframes._ +import org.locationtech.rasterframes.encoders.CatalystSerializer.{CatalystIO, _} +import org.locationtech.rasterframes.encoders.{CatalystSerializer, CatalystSerializerEncoder} +import org.locationtech.rasterframes.ref.RasterSource._ +import org.locationtech.rasterframes.expressions.transformers.PatternToRasterSources._ + + +abstract class RFTensor extends CellGrid { + def shape: Seq[Int] + def depth: Int = shape(0) + def rows: Int = shape(1) + def cols: Int = shape(2) + def cellType: CellType = DoubleCellType +} \ No newline at end of file